In [3]:
# requires to install eofs and gpytorch
import xarray as xr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import gpytorch
import os
import glob
from eofs.xarray import Eof

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from typing import Dict, Optional, List, Callable, Tuple, Union

import wandb
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

Notes:
- As in ClimateBench, we use EOF as a dim reduction technique on the aerosols input (BC and SO2) and keeps only the 5 first modes.  
- We use a different version of the Gaussian Process (GP) from the ClimateBench paper. We use a stochastic variational variant (SVGP; see https://arxiv.org/pdf/1411.2005.pdf) that supports training with minibatches and can scale to large datasets. It relies on the library GPytorch (https://gpytorch.ai/ and https://arxiv.org/pdf/1809.11165.pdf), a Pytorch implementation of gaussian processes.
- For spatial output, the Linear Model of Coregionalization (LMC) seems to be the best multitask model (run in a reasonable amounts of time; see https://docs.gpytorch.ai/en/stable/examples/04_Variational_and_Approximate_GPs/SVGP_Multitask_GP_Regression.html). The number of latents used is an hyperparameter than control the capacity of the model.

### Train and test dataloading

In [4]:
input_dir = '/home/mila/v/venkatesh.ramesh/scratch/causal_data/inputs/input4mips'
target_dir = '/home/mila/v/venkatesh.ramesh/scratch/causal_data/targets/CMIP6'

models = ['NorESM2-LM']
fire_type = 'all-fires'
variables = ['pr']
train_experiments = ["ssp585", "ssp126", "ssp370"] 
test_experiments = ["ssp245"]
input_gases = ['BC_sum', 'CH4_sum', 'CO2_sum', 'SO2_sum']
total_ensembles = 1 #-1 for all

In [7]:
def load_train_data(mode: str = 'train'):
    X, (so2_solver, bc_solver) = get_input_data(input_dir, mode)
    y = get_output_data(target_dir, mode)
    return torch.tensor(X), torch.tensor(y), (so2_solver, bc_solver)


def load_test_data(mode: str = 'train', solvers = None):
    X, (so2_solver, bc_solver) = get_input_data(input_dir, mode, solvers)
    y = get_output_data(target_dir, mode)
    return torch.tensor(X), torch.tensor(y), (so2_solver, bc_solver)


def load_data_npz(path: str): #If np data already exists
    X_train, y_train = np.load(os.path.join(base_dir, ''))
    X_test, y_test = np.load(os.path.join(base_dir, ''))
    return X_train, y_train, X_test, y_test


def get_input_data(path: str, mode: str, solvers = None, n_eofs : int = 5):
    BC = []
    CH4 = []
    CO2 = []
    SO2 = []
    
    if mode == 'train':      
        experiments = train_experiments
    elif mode == 'test':
        experiments = test_experiments
        
    for exp in experiments:
        for gas in input_gases:
            var_dir = os.path.join(path, exp, gas, 'map_250_km/mon')
            files = glob.glob(var_dir + '/**/*.nc', recursive=True)

            for f in files:
                if gas == 'BC_sum' and fire_type in f:
                    BC.append(f)
            for f in files:
                if gas == 'CH4_sum' and fire_type in f:
                    CH4.append(f)
            for f in files:
                if gas == 'BC_sum' and fire_type in f:
                    SO2.append(f)
            for f in files:
                if gas == 'CO2_sum':
                    CO2.append(f)

    BC_data = xr.open_mfdataset(BC, concat_dim='time', combine='nested').compute().to_array()  # .to_numpy()
    SO2_data = xr.open_mfdataset(SO2, concat_dim='time', combine='nested').compute() .to_array()  #.to_numpy()
    CH4_data = xr.open_mfdataset(CH4, concat_dim='time', combine='nested').compute().to_array().to_numpy()
    CO2_data = xr.open_mfdataset(CO2, concat_dim='time', combine='nested').compute().to_array().to_numpy()
    
    # BC_data = np.moveaxis(BC_data, 0, 1)
    # SO2_data = np.moveaxis(SO2_data, 0, 1)
    CH4_data = np.moveaxis(CH4_data, 0, 1)
    CO2_data = np.moveaxis(CO2_data, 0, 1)
    CH4_data = CH4_data.reshape(CH4_data.shape[0], -1)
    CO2_data = CH4_data.reshape(CO2_data.shape[0], -1)

    
    BC_data = BC_data.transpose('time', 'variable', 'lat', 'lon')
    SO2_data = SO2_data.transpose('time', 'variable', 'lat', 'lon')
    BC_data = BC_data.assign_coords(time=np.arange(len(BC_data.time)))
    SO2_data = SO2_data.assign_coords(time=np.arange(len(SO2_data.time)))

    
    # Compute EOFs for BC
    if solvers is None:
        # print(BC_data.shape)
        bc_solver = Eof(BC_data)
        bc_eofs = bc_solver.eofsAsCorrelation(neofs=n_eofs)
        bc_pcs = bc_solver.pcs(npcs=n_eofs, pcscaling=1)

        # Compute EOFs for SO2
        so2_solver = Eof(SO2_data)
        so2_eofs = so2_solver.eofsAsCorrelation(neofs=n_eofs)
        so2_pcs = so2_solver.pcs(npcs=n_eofs, pcscaling=1)

        print(bc_pcs)

        # Convert to pandas
        bc_df = bc_pcs.to_dataframe().unstack('mode')
        bc_df.columns = [f"BC_{i}" for i in range(n_eofs)]

        so2_df = so2_pcs.to_dataframe().unstack('mode')
        so2_df.columns = [f"SO2_{i}" for i in range(n_eofs)]
    else:
        so2_solver = solvers[0]
        bc_solver = solvers[1]
        
        so2_pcs = so2_solver.projectField(SO2_data, neofs=n_eofs, eofscaling=1)
        so2_df = so2_pcs.to_dataframe().unstack('mode')
        so2_df.columns = [f"SO2_{i}" for i in range(n_eofs)]

        bc_pcs = bc_solver.projectField(BC_data, neofs=n_eofs, eofscaling=1)
        bc_df = bc_pcs.to_dataframe().unstack('mode')
        bc_df.columns = [f"BC_{i}" for i in range(n_eofs)]
    
    CH4_data = CH4_data[:, :1]
    CO2_data = CO2_data[:, :1]

    print(bc_df.shape)
    print(CH4_data.shape)
    print(CO2_data.shape)
    print(so2_df.shape)
    
    merged_data = np.concatenate((bc_df, CH4_data, CO2_data, so2_df), axis=1)
    return merged_data, (so2_solver, bc_solver)


def get_output_data(path: str, mode: str):
    nc_files = []
    
    if mode == 'train':
        experiments = train_experiments
    elif mode == 'test':
        experiments = test_experiments
        
    for mod in models:

        model_dir = os.path.join(path, mod)
        ensembles = os.listdir(model_dir)

        if total_ensembles == 1:
            ensembles = ensembles[0]
        
        exp_counter = 0
        for exp in experiments:
            for var in variables:
                var_dir = os.path.join(path, mod, ensembles, exp, var, '250_km/mon')
                files = glob.glob(var_dir + '/**/*.nc', recursive=True)
                nc_files += files
        
            if exp_counter == 0:
                dataset = xr.open_mfdataset(nc_files).compute().to_array().to_numpy()
        
            else: #concatenate dataset in time dimension
                other_experiment = xr.open_mfdataset(nc_files).compute().to_array().to_numpy()
                dataset = np.concatenate((dataset, other_experiment), axis=1)
                
                
            exp_counter += 1
            
        dataset = np.moveaxis(dataset, 0, 1)
        print(dataset.shape)
        dataset = dataset.reshape(dataset.shape[0], -1)
        
        # TODO: remove next line, only used for making quick tests
        dataset = dataset[:, :1]
    
    return dataset

In [8]:
X_train, y_train, (so2_solver, bc_solver) = load_train_data('train')

<xarray.DataArray 'pcs' (time: 3096, mode: 5)>
array([[-1.3052769 , -1.2704228 ,  0.41740844, -1.1416101 , -2.3582041 ],
       [-1.1239995 , -0.90367   ,  0.21111041, -1.577814  , -0.45786643],
       [-0.93241996, -0.81854814,  0.19349629, -1.7038045 ,  0.5746188 ],
       ...,
       [-0.7419013 ,  0.5875077 ,  1.0679767 ,  1.3232024 ,  1.6523781 ],
       [-0.6475205 , -0.62544924,  0.31076536,  1.3729664 ,  1.119416  ],
       [-0.67845637, -0.8426438 ,  0.26810127,  1.4787649 , -0.32489872]],
      dtype=float32)
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 ... 3089 3090 3091 3092 3093 3094 3095
  * mode     (mode) int64 0 1 2 3 4
(3096, 5)
(3096, 1)
(3096, 1)
(3096, 5)
(3096, 1, 96, 144)


In [9]:
X_test, y_test, (so2_solver, bc_solver) = load_test_data('test', (so2_solver, bc_solver))

(1032, 5)
(1032, 1)
(1032, 1)
(1032, 5)
(1032, 1, 96, 144)


In [10]:
# just for test:
class ClimateDataset(Dataset):
    def __init__(self, X, y):
        # global
        self.X = X
        self.y = y
        
    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

training_data = ClimateDataset(X_train, y_train)
test_data = ClimateDataset(X_test, y_test)

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [7]:
# alternative dataset, just used for test:
class DummyDataset(Dataset):
    def __init__(self, n=3000):
        # global
        self.X = torch.rand([n, 1440])  # 4
        self.y = torch.rand([n, 1440])  # 2
        
        # spatial
        # self.X = torch.rand([n, 4 * 96 * 144])
        # self.y = torch.rand([n, 1 * 96 * 144])
        
    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# training_data = DummyDataset(n=500)
# test_data = DummyDataset(n=100)

# train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

### Gaussian Process Model (GP)

In [11]:
# hyperparameters
num_inducing_points = 500
n_epochs = 2  # Could use criterion to stop
lr = 0.1
# optimizer = adam
# kernel = matern3/2

class ApproxGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points, num_tasks):
        # inducing_points size: num_outputs, num_examples, num_features
        inducing_points = inducing_points.reshape(1, inducing_points.size(0), -1)
        # inducing_points = inducing_points.repeat(num_tasks, 1, 1)
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(-2), batch_shape=torch.Size([num_tasks]))

        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks
        )

        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_tasks]))

        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=1.5, batch_shape=torch.Size([num_tasks])), batch_shape=torch.Size([num_tasks]))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class GaussianProcess(nn.Module):
    def __init__(self,
                 inducing_points,
                 num_out_var):
        super().__init__()
        self.model = ApproxGPModel(inducing_points=inducing_points, num_tasks=num_out_var)
        self.likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_out_var)

    def forward(self, x):
        x = x.reshape(x.size(0), -1)
        return self.model(x)

    def predict(self, x):
        predictions = self.likelihood(self.model(X))
        return predictions.mean

In [12]:
training_data.X[:num_inducing_points].shape

torch.Size([500, 12])

In [13]:
gp_model = GaussianProcess(training_data.X[:num_inducing_points], training_data.y.size(1))

### GP: Training

In [14]:
optimizer = torch.optim.Adam([
    {'params': gp_model.parameters()}
], lr=lr)
mll = gpytorch.mlls.VariationalELBO(gp_model.likelihood, gp_model.model, num_data=training_data.y.size(0))
gp_model.train()

GaussianProcess(
  (model): ApproxGPModel(
    (variational_strategy): IndependentMultitaskVariationalStrategy(
      (base_variational_strategy): VariationalStrategy(
        (_variational_distribution): CholeskyVariationalDistribution()
      )
    )
    (mean_module): ConstantMean()
    (covar_module): ScaleKernel(
      (base_kernel): MaternKernel(
        (raw_lengthscale_constraint): Positive()
      )
      (raw_outputscale_constraint): Positive()
    )
  )
  (likelihood): MultitaskGaussianLikelihood(
    (raw_task_noises_constraint): GreaterThan(1.000E-04)
    (raw_noise_constraint): GreaterThan(1.000E-04)
  )
)

In [15]:
n_epochs = 1

In [16]:
for i in range(n_epochs):
    print(f"epoch #{i}")
    for x, y in train_dataloader:
        optimizer.zero_grad()
        output = gp_model(x)
        loss = -mll(output, y)
        loss.backward()
        optimizer.step()

torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([64, 1])
torch.Size([24, 1])


### GP: Evaluation

In [17]:
# Evaluate
gp_model.eval()

with torch.no_grad(), gpytorch.settings.fast_pred_var():
    predictions = gp_model.likelihood(gp_model(test_data.X))
    y_pred = predictions.mean
    # lower, upper = predictions.confidence_region()  # could have confidence regions if necessary

In [18]:
# y_pred = y_pred.reshape(3096, 1, 96, 144)
rmse = mean_squared_error(test_data.y, y_pred, squared=False)
rmse

0.005023047