In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from sklearn.model_selection import train_test_split

import xarray as xr
import os
import torch
from functools import reduce 

# Select the GPU when one is available. The result is bound to a name so later
# cells can actually use it (a bare `torch.device(...)` expression is discarded).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LOW_RES_SAMPLE_PATH = "data/ClimSim_low-res/train/"
LOW_RES_GRID_PATH = "data/ClimSim_low-res/ClimSim_low-res_grid-info.nc"
ZARR_PATH = "data/ClimSim_low-res.zarr"

In [117]:
class ClimSimMLP(nn.Module):
    """Feed-forward MLP with a linear "tendencies" head and a ReLU "surface" head.

    Architecture: input -> 768 -> 640 -> 512 -> 640 -> 640 -> 128, with
    LeakyReLU(negative_slope=0.15) after every hidden layer. Two output heads then
    read the 128-wide representation:

    * ``head_tendencies``: ``output_tendancies_dim`` outputs, no activation;
    * ``head_surface``: ``output_surface_dim`` outputs passed through ReLU, so they
      are always >= 0.

    The two heads are concatenated along dim 1, tendencies first. The output width
    is therefore ``output_tendancies_dim + output_surface_dim``.

    Args:
        input_dim: number of input features per row.
        output_tendancies_dim: width of the unconstrained (linear) head.
        output_surface_dim: width of the non-negative (ReLU) head.
    """

    def __init__(self, input_dim=556, output_tendancies_dim=120, output_surface_dim=8):
        super(ClimSimMLP, self).__init__()

        # Hidden layers: [768, 640, 512, 640, 640]. The attribute names are kept
        # stable because they define the state_dict keys of saved checkpoints.
        self.layer1 = nn.Linear(input_dim, 768)
        self.layer2 = nn.Linear(768, 640)
        self.layer3 = nn.Linear(640, 512)
        self.layer4 = nn.Linear(512, 640)
        self.layer5 = nn.Linear(640, 640)

        # Fixed-width bottleneck shared by both output heads.
        self.last_hidden = nn.Linear(640, 128)

        # --- Output heads ---
        self.head_tendencies = nn.Linear(128, output_tendancies_dim)
        self.head_surface = nn.Linear(128, output_surface_dim)

        # LeakyReLU with alpha = 0.15, shared by every hidden layer.
        self.activation = nn.LeakyReLU(0.15)

    def forward(self, x):
        """Map ``x`` of shape (rows, input_dim) to (rows, tendencies + surface)."""
        # Five main hidden layers.
        x = self.activation(self.layer1(x))
        x = self.activation(self.layer2(x))
        x = self.activation(self.layer3(x))
        x = self.activation(self.layer4(x))
        x = self.activation(self.layer5(x))

        # Fixed 128-wide layer feeding both heads.
        x = self.activation(self.last_hidden(x))

        # Output 1: tendencies (linear activation).
        out_linear = self.head_tendencies(x)

        # Output 2: surface variables (ReLU keeps them non-negative).
        out_relu = F.relu(self.head_surface(x))

        # Concatenate along the feature dimension (dim=1), tendencies first.
        return torch.cat([out_linear, out_relu], dim=1)

@torch.no_grad()
def evaluate_model(model, dataloader, criterion, device, input_dim, output_dim):
    """Return the row-weighted average loss of ``model`` over ``dataloader``.

    Each batch is flattened to ``(-1, input_dim)`` / ``(-1, output_dim)`` so that an
    MLP sees one row per leading-axis element. The scalar loss of each batch is
    multiplied by its number of flattened rows before averaging. This makes the
    result a true per-row mean across the whole loader, assuming ``criterion``
    returns a per-row mean (as ``nn.MSELoss()`` does with its default reduction).

    Args:
        model: module mapping (rows, input_dim) -> (rows, output_dim).
        dataloader: iterable yielding ``(inputs, targets)`` tensor pairs.
        criterion: loss callable returning a scalar tensor.
        device: device the batches are moved to.
        input_dim: feature width used to flatten ``inputs``.
        output_dim: feature width used to flatten ``targets``.

    Returns:
        The average loss as a Python float.

    Raises:
        ValueError: if ``dataloader`` yields no rows at all.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0

    for inputs, targets in dataloader:
        # reshape (rather than view) also accepts non-contiguous tensors.
        inputs = inputs.to(device).reshape(-1, input_dim)
        targets = targets.to(device).reshape(-1, output_dim)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        n_rows = inputs.size(0)
        total_loss += loss.item() * n_rows
        total_samples += n_rows

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute an average loss")
    return total_loss / total_samples

def train_one_epoch(model, dataloader, optimizer, criterion, device, input_dim, output_dim):
    """Run one optimisation pass over ``dataloader`` and return the mean training loss.

    Each batch is flattened to ``(-1, input_dim)`` / ``(-1, output_dim)`` so an MLP
    sees one row per leading-axis element. The loss is back-propagated and
    ``optimizer`` is stepped once per batch. The returned value weights each
    batch's scalar loss by its number of flattened rows, assuming ``criterion``
    returns a per-row mean (as ``nn.MSELoss()`` does by default). A tqdm progress
    bar shows the latest batch loss.

    Args:
        model: module mapping (rows, input_dim) -> (rows, output_dim).
        dataloader: iterable yielding ``(inputs, targets)`` tensor pairs.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable returning a scalar tensor.
        device: device the batches are moved to.
        input_dim: feature width used to flatten ``inputs``.
        output_dim: feature width used to flatten ``targets``.

    Returns:
        The row-weighted average training loss as a Python float.

    Raises:
        ValueError: if ``dataloader`` yields no rows at all.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0

    pbar = tqdm(dataloader, desc="Training", unit="batch")

    for inputs, targets in pbar:
        # reshape (rather than view) also accepts non-contiguous tensors.
        inputs = inputs.to(device).reshape(-1, input_dim)
        targets = targets.to(device).reshape(-1, output_dim)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # .item() synchronises with the device; call it once and reuse the value.
        batch_loss = loss.item()
        n_rows = inputs.size(0)
        total_loss += batch_loss * n_rows
        total_samples += n_rows

        # Show the current batch loss on the progress bar.
        pbar.set_postfix({"loss": f"{batch_loss:.6f}"})

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute an average loss")
    return total_loss / total_samples

In [115]:
import torch
import xarray as xr
import numpy as np
from torch.utils.data import Dataset

class ClimSimZarrDataset(Dataset):
    """Map-style dataset serving ClimSim samples from a Zarr store.

    Only the variables named in ``features`` are opened, lazily, via
    ``xr.open_zarr``. Each item is one entry of the ``sample`` dimension,
    returned as an ``(x, y)`` pair of float32 tensors laid out as (ncol, nfeatures).
    Multi-level variables contribute one column per level. Single-level (surface)
    variables contribute one column each.

    Args:
        zarr_path: path to the Zarr store.
        grid_path: path to the grid-info NetCDF file. It is opened and kept on
            ``self.grid`` but not otherwise used by this class.
        features: nested dict
            ``{"features"|"target": {"tendancies"|"surface": [variable names]}}``.
            The ``"features"`` lists become the inputs and the ``"target"`` lists
            the outputs, each in tendancies-then-surface order.
        transform: stored on ``self.transform``; not currently applied.
        max_samples: keep only the first ``max_samples`` entries of the ``sample``
            dimension (default 1000, for quick experiments). ``None`` keeps all.
    """

    def __init__(self, zarr_path, grid_path, features, transform=None, max_samples=1000):
        self.zarr_path = zarr_path

        self.features = features
        self.features_list = self.__get_features__()

        ds = xr.open_zarr(zarr_path)[self.features_list]
        if max_samples is not None:
            ds = ds.isel(sample=slice(0, max_samples))
        self.ds = ds
        self.grid = xr.open_dataset(grid_path)
        self.transform = transform

        # Dataset.sizes is the name -> length mapping; indexing Dataset.dims by
        # name is deprecated in recent xarray and emits a FutureWarning.
        self.length = self.ds.sizes["sample"]

        # Take variable roles from the features dict itself instead of matching
        # 'in'/'out' substrings in variable names, which breaks on names that
        # happen to contain those letters.
        self.input_vars = [*features["features"]["tendancies"], *features["features"]["surface"]]
        self.output_vars = [*features["target"]["tendancies"], *features["target"]["surface"]]

    def __len__(self):
        return self.length

    def __get_features__(self):
        """Return every variable to load: inputs first, then targets."""
        # np.concatenate works on all NumPy versions (np.concat is NumPy >= 2.0 only).
        feat = np.concatenate([self.features["features"]["tendancies"], self.features["features"]["surface"]])
        target = np.concatenate([self.features["target"]["tendancies"], self.features["target"]["surface"]])
        return np.concatenate([feat, target])

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(x, y)`` float32 tensors of shape (ncol, nfeatures)."""
        sample = self.ds.isel(sample=idx).load()

        def prepare_data(vars_list):
            # Bring every variable to (ncol, nfeatures) so they can be concatenated.
            output_list = []
            for var in vars_list:
                # Expected (lev, ncol), e.g. (60, 384), or (ncol,), e.g. (384,);
                # the lev-first layout is assumed, not checked.
                data = sample[var].values
                if data.ndim == 2:
                    data = data.T  # -> (ncol, lev)
                else:
                    # Surface variable (ncol,) -> (ncol, 1).
                    data = data[:, np.newaxis]

                output_list.append(data)

            # All pieces are now (ncol, nfeatures_i).
            return np.concatenate(output_list, axis=1).astype(np.float32)

        x = prepare_data(self.input_vars)   # e.g. (384, 246) with the notebook's FEATURES
        y = prepare_data(self.output_vars)  # e.g. (384, 61) with the notebook's FEATURES

        return torch.from_numpy(x), torch.from_numpy(y)

    def get_models_dims(self, variables_dict):
        """Compute MLP input/output widths for ``variables_dict``.

        Tendency variables count one feature per ``lev`` level (or 1 if they have
        no ``lev`` dimension); every surface variable counts as one feature.

        Returns:
            dict with keys ``"input_total"``, ``"output_tendancies"`` and
            ``"output_surface"``.
        """
        features_tend = variables_dict["features"]["tendancies"]
        features_surf = variables_dict["features"]["surface"]

        target_tend = variables_dict["target"]["tendancies"]
        target_surf = variables_dict["target"]["surface"]

        def get_var_dim(var):
            if 'lev' in self.ds[var].dims:
                return self.ds[var].sizes['lev']
            return 1

        in_tend_dim = sum(get_var_dim(var) for var in features_tend)
        in_surf_dim = len(features_surf)

        out_tend_dim = sum(get_var_dim(var) for var in target_tend)
        out_surf_dim = len(target_surf)

        return {
            "input_total": in_tend_dim + in_surf_dim,
            "output_tendancies": out_tend_dim,
            "output_surface": out_surf_dim
        }

In [119]:
BATCH_SIZE = 100
N_EPOCHS = 10
SEED = 42  # same value as the random_state used for the train/test split

FEATURES = {
    "features" :{
        "tendancies" : ["in_state_t", "in_state_q0001", "in_state_u", "in_state_v"],
        "surface" : ["in_pbuf_COSZRS", "in_pbuf_LHFLX", "in_pbuf_SHFLX", "in_pbuf_TAUX", "in_pbuf_TAUY", "in_pbuf_SOLIN"],
    },  
    "target" :{
        "tendancies" : ["out_state_t"],
        "surface" : ["out_cam_out_SOLL"]
    }
}

dataset = ClimSimZarrDataset(ZARR_PATH, LOW_RES_GRID_PATH, FEATURES)

model_dims = dataset.get_models_dims(FEATURES)

# Seed torch right before building the model so its random weight initialisation
# (and therefore the whole training run) is reproducible across kernel restarts.
torch.manual_seed(SEED)

model = ClimSimMLP(input_dim=model_dims["input_total"], output_tendancies_dim=model_dims["output_tendancies"], output_surface_dim=model_dims["output_surface"])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

  self.length = self.ds.dims['sample']


In [79]:
# Split sample *indices*, not the Dataset object. Passing the Dataset directly makes
# sklearn index every element (X[i] for each i), which loads every sample from the
# Zarr store into memory before training starts. Splitting range(len(dataset)) with
# the same test_size / random_state selects the same positions, and Subset keeps
# per-sample loading lazy inside the DataLoader.
train_indices, test_indices = train_test_split(
    list(range(len(dataset))), test_size=0.2, random_state=42
)
train = torch.utils.data.Subset(dataset, train_indices)
test = torch.utils.data.Subset(dataset, test_indices)

train_loader = torch.utils.data.DataLoader(
    train, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    num_workers=0,
)

test_loader = torch.utils.data.DataLoader(
    test, 
    batch_size=BATCH_SIZE, 
    shuffle=False,
    num_workers=0,
)

In [121]:
# Move batches to wherever the model's parameters live instead of hard-coding "cpu",
# so the data and the model can never end up on different devices.
device = next(model.parameters()).device
input_dim = model_dims["input_total"]
output_dim = model_dims["output_tendancies"] + model_dims["output_surface"]

for epoch in range(N_EPOCHS):
    train_loss = train_one_epoch(
        model, 
        train_loader, 
        optimizer, 
        criterion, 
        device=device,
        input_dim=input_dim,
        output_dim=output_dim,
        )
    val_loss = evaluate_model(
        model, 
        test_loader, 
        criterion, 
        device=device,
        input_dim=input_dim,
        output_dim=output_dim,
        )
    
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

Training: 100%|██████████| 8/8 [00:13<00:00,  1.67s/batch, loss=10208.026367]


Epoch 1, Train Loss: 22858.4716, Val Loss: 15851.1450


Training: 100%|██████████| 8/8 [00:13<00:00,  1.66s/batch, loss=2374.431885] 


Epoch 2, Train Loss: 6074.3344, Val Loss: 2859.3103


Training: 100%|██████████| 8/8 [00:14<00:00,  1.75s/batch, loss=1438.586914]


Epoch 3, Train Loss: 2390.0921, Val Loss: 747.7397


Training: 100%|██████████| 8/8 [00:14<00:00,  1.84s/batch, loss=950.642395] 


Epoch 4, Train Loss: 869.4419, Val Loss: 771.9193


Training: 100%|██████████| 8/8 [00:15<00:00,  1.89s/batch, loss=486.196075]


Epoch 5, Train Loss: 552.4299, Val Loss: 540.0049


Training: 100%|██████████| 8/8 [00:14<00:00,  1.84s/batch, loss=344.856689]


Epoch 6, Train Loss: 408.3157, Val Loss: 397.2182


Training: 100%|██████████| 8/8 [00:14<00:00,  1.81s/batch, loss=302.930115]


Epoch 7, Train Loss: 341.2494, Val Loss: 327.2119


Training: 100%|██████████| 8/8 [00:13<00:00,  1.71s/batch, loss=294.565369]


Epoch 8, Train Loss: 309.6892, Val Loss: 306.7909


Training: 100%|██████████| 8/8 [00:14<00:00,  1.86s/batch, loss=287.457428]


Epoch 9, Train Loss: 296.2079, Val Loss: 295.5815


Training: 100%|██████████| 8/8 [00:16<00:00,  2.06s/batch, loss=287.965729]


Epoch 10, Train Loss: 289.6295, Val Loss: 290.8185
