In [1]:
from robust_analysis import train_ridge_regression, train_robust_model, compute_weights,\
                            leave_one_out, leave_one_out_procedure, cross_validation_loo

import pickle
import os 
import netCDF4 as netcdf
import skimage
import numpy as np
import torch 

with open('ssp585_time_series.pkl', 'rb') as f:
    dic_ssp585 = pickle.load(f)

# Get the list of all files and directories
path = "/net/atmos/data/cmip6-ng/tos/ann/g025"
dir_list = os.listdir(path)

print("Files and directories in '", path, "' :")

list_model = []
list_forcing = []

for idx, file in enumerate(dir_list):

    file_split = file.split("_")
    
    # extract model names
    model_name = file_split[2]
    forcing = file_split[3]
    run_name = file_split[4]
    
    list_model.append(model_name)
    list_forcing.append(forcing)
    
model_names = list(set(list_model))
forcing_names = list(set(list_forcing))


# define the file
file = '/net/h2o/climphys3/simondi/cope-analysis/data/erss/sst_annual_g050_mean_19812014_centered.nc'

# read the dataset
file2read = netcdf.Dataset(file,'r')

# load longitude, latitude and sst monthly means
lon = np.array(file2read.variables['lon'][:])
lat = np.array(file2read.variables['lat'][:])
sst = np.array(file2read.variables['sst'])

# define grid
lat_grid, lon_grid = np.meshgrid(lat, lon, indexing='ij')

time_period = 33
grid_lat_size = lat.shape[0]
grid_lon_size = lon.shape[0]

# first filter out the models that do not contain ensemble members 
dic_reduced_ssp585 = {}

for m in list(dic_ssp585.keys()):
    if len(dic_ssp585[m].keys()) > 2:
        dic_reduced_ssp585[m] = dic_ssp585[m].copy()
        for idx_i, i in enumerate(dic_ssp585[m].keys()):
            dic_reduced_ssp585[m][i] = skimage.transform.downscale_local_mean(dic_reduced_ssp585[m][i],(1,2,2))
            lat_size = dic_reduced_ssp585[m][i][0,:,:].shape[0]
            lon_size = dic_reduced_ssp585[m][i][0,:,:].shape[1]

######## Store Nan indices 

nan_idx = []
for idx_m,m in enumerate(dic_reduced_ssp585.keys()):
    for idx_i,i in enumerate(dic_reduced_ssp585[m].keys()):    

        nan_idx_tmp = list(np.where(np.isnan(dic_reduced_ssp585[m][i][0,:,:].ravel())==True)[0])        
        nan_idx = list(set(nan_idx) | set(nan_idx_tmp))

notnan_idx = list(set(list(range(lon_size*lat_size))) - set(nan_idx))

############################


# second, for each model we compute the anomalies 
dic_processed_ssp585 = {}

import numpy as np

for idx_m,m in enumerate(dic_reduced_ssp585.keys()):
    dic_processed_ssp585[m] = dic_reduced_ssp585[m].copy()
    
    mean_ref_ensemble = 0
    for idx_i, i in enumerate(dic_reduced_ssp585[m].keys()):
        y_tmp = dic_reduced_ssp585[m][i][131:164,:,:].copy().reshape(time_period, lat_size*lon_size)
        y_tmp[:,nan_idx] = float('nan')

        if idx_i == 0:
            mean_ref_ensemble = np.nanmean(y_tmp,axis=0)/ len(dic_processed_ssp585[m].keys())
        else:
            mean_ref_ensemble += np.nanmean(y_tmp,axis=0)/ len(dic_processed_ssp585[m].keys())

    for idx_i, i in enumerate(dic_processed_ssp585[m].keys()):
        dic_processed_ssp585[m][i] = y_tmp - mean_ref_ensemble


# compute the forced response
dic_forced_response_ssp585 = dict({})

for idx_m,m in enumerate(dic_reduced_ssp585.keys()):
    dic_forced_response_ssp585[m] = dic_reduced_ssp585[m].copy()

    for idx_i, i in enumerate(dic_forced_response_ssp585[m].keys()):
        
        y_tmp = dic_reduced_ssp585[m][i][131:164,:,:].copy().reshape(time_period, lat_size*lon_size)
        y_tmp[:,nan_idx] = float('nan')
        
        if idx_i == 0:
            mean_spatial_ensemble = np.nanmean(y_tmp,axis=1)/ len(dic_forced_response_ssp585[m].keys())
        else:
            mean_spatial_ensemble += np.nanmean(y_tmp,axis=1)/ len(dic_forced_response_ssp585[m].keys())

    for idx_i, i in enumerate(dic_forced_response_ssp585[m].keys()):        
        dic_forced_response_ssp585[m][i] = mean_spatial_ensemble - np.nanmean(mean_spatial_ensemble)

y_forced_response = {}
x_predictor = {}

for idx_m,m in enumerate(dic_processed_ssp585.keys()):
    y_forced_response[m] = {}
    x_predictor[m] = {}

    for idx_i, i in enumerate(dic_forced_response_ssp585[m].keys()):       
        y_forced_response[m][i] = dic_forced_response_ssp585[m][i]
        x_predictor[m][i] = dic_processed_ssp585[m][i]
        x_predictor[m][i][:,nan_idx] = float('nan')

y_forced_response_concatenate = {}
x_predictor_concatenate = {}

for idx_m,m in enumerate(dic_processed_ssp585.keys()):
    y_forced_response_concatenate[m] = 0
    x_predictor_concatenate[m] = 0
    
    for idx_i, i in enumerate(dic_forced_response_ssp585[m].keys()):
        if idx_i ==0:
            y_forced_response_concatenate[m] = dic_forced_response_ssp585[m][i]
            x_predictor_concatenate[m] = dic_processed_ssp585[m][i]
        else:
            y_forced_response_concatenate[m] = np.concatenate([y_forced_response_concatenate[m],dic_forced_response_ssp585[m][i]])
            x_predictor_concatenate[m] = np.concatenate([x_predictor_concatenate[m], dic_processed_ssp585[m][i]],axis=0)  
    x_predictor_concatenate[m][:,nan_idx] = float('nan')


# compute the variance
variance_processed_ssp585 = {}
std_processed_ssp585 = {}
for idx_m,m in enumerate(x_predictor.keys()):
    variance_processed_ssp585[m] = {}
    arr_tmp = np.zeros((len(x_predictor[m].keys()),33))
    
    for idx_i, i in enumerate(list(dic_processed_ssp585[m].keys())):
        arr_tmp[idx_i,:] = np.nanmean(x_predictor[m][i],axis=1)

    arr_tmp_values = np.zeros((len(x_predictor[m].keys()),33))
    for idx_i, i in enumerate(x_predictor[m].keys()):
        arr_tmp_values[idx_i,:] = (y_forced_response[m][i] - arr_tmp[idx_i,:])**2

    # variance_processed_ssp585[m] = torch.nanmean(torch.from_numpy(arr_tmp_values),axis=0)
    variance_processed_ssp585[m] = torch.mean(torch.nanmean(torch.from_numpy(arr_tmp_values),axis=0))

# Data preprocessing
x_train = {}
y_train = {}

for idx_m,m in enumerate(dic_reduced_ssp585.keys()):
    x_train[m] = {}
    y_train[m] = {}
    for idx_i, i in enumerate(dic_processed_ssp585[m].keys()):
        x_train[m][i] = torch.nan_to_num(torch.from_numpy(x_predictor[m][i])).to(torch.float64)
        y_train[m][i] = torch.from_numpy(y_forced_response[m][i]).to(torch.float64)

Files and directories in ' /net/atmos/data/cmip6-ng/tos/ann/g025 ' :


  mean_ref_ensemble = np.nanmean(y_tmp,axis=0)/ len(dic_processed_ssp585[m].keys())
  mean_ref_ensemble += np.nanmean(y_tmp,axis=0)/ len(dic_processed_ssp585[m].keys())


In [2]:
class SAGA:
    def __init__(self, X, y, loss_fn, grad_loss_fn, lr=0.1, epochs=100):
        self.X = X
        self.y = y
        self.loss_fn = loss_fn
        self.grad_loss_fn = grad_loss_fn
        self.lr = lr
        self.epochs = epochs
        self.n_samples, self.n_features = X.shape
        self.w = torch.zeros(self.n_features, dtype=torch.float32, requires_grad=True)
        self.grad_store = torch.zeros((self.n_samples, self.n_features), dtype=torch.float32)

    def fit(self):
        for epoch in range(self.epochs):
            indices = torch.randperm(self.n_samples)
            avg_grad_w = torch.mean(self.grad_store, dim=0)
            
            for i in indices:
                xi = self.X[i]
                yi = self.y[i]
                
                pred = torch.dot(xi, self.w) 
                
                grad_w = self.grad_loss_fn(pred, yi, xi)
                
                new_w = self.w - self.lr * (grad_w - self.grad_store[i, :] + avg_grad_w)
                
                self.w = new_w
                
                self.grad_store[i, :] = grad_w.detach()

            # Compute and display the loss for the current epoch
            current_loss = self.compute_loss()
            print(f"Epoch {epoch + 1}/{self.epochs}, Loss: {current_loss:.4f}")

    def compute_loss(self):
        preds = self.X @ self.w 
        return self.loss_fn(preds, self.y).item()

    def predict(self, X):
        return X @ self.w 

In [51]:
# # Generate some synthetic data
# torch.manual_seed(42)
# X = torch.randn(100, 2)
# y = (torch.sigmoid(X[:, 0] * 2 - X[:, 1] * 3) > 0.5).float()

mu_=1.0

# Define the log-sum-exp loss function for squared residuals
def log_sum_exp_loss(preds, targets):

    res = torch.zeros(len(preds.keys()))
    for idx_m, m in enumerate(preds.keys()):            
        res[idx_m] = torch.mean((targets[m] - preds[m])**2/variance_processed_ssp585[m], axis=0)
        
    # squared_residuals = (preds - targets) ** 2
    return torch.logsumexp((1/mu_)*res, dim=0)

# Define the gradient of the log-sum-exp loss function for squared residuals
def log_sum_exp_grad(pred, y, x):
    
    res = torch.zeros(len(x.keys()))
    for idx_m, m in enumerate(x.keys()):            
        res[idx_m] = torch.mean((y[m] - pred[m])**2/variance_processed_ssp585[m],axis=0)
    # squared_residual = (pred - y) ** 2
    
    exp_term = torch.exp(res)
    grad_common_term = (2 * exp_term) / torch.sum(exp_term)
    print(grad_common_term.shape)

    grad_w = torch.zeros(len(pred.keys()), 2592)

    
    for idx_m, m in enumerate(pred.keys()):
        grad_w[idx_m,:] = grad_common_term[idx_m] * torch.matmul((y[m]*torch.ones(x[m].shape[0]) - pred[m]*torch.ones(x[m].shape[0])).to(torch.float64),torch.tensor(x[m]).to(torch.float64)) # x belongs to R^(M x (runs*T) x Grid)
    print(grad_w.shape)
    return torch.sum(grad_w,axis=0)

In [54]:
y_0 = {m: torch.ones(x_predictor_concatenate[m].shape[0]) for idx_m, m in enumerate(x_predictor.keys())}
y_1 = {m: torch.zeros(x_predictor_concatenate[m].shape[0]) for idx_m, m in enumerate(x_predictor.keys())}
log_sum_exp_loss(y_0, y_1)

tensor(1062.2227)

In [17]:
res = torch.zeros(len(x_predictor.keys()))
for idx_m, m in enumerate(x_predictor.keys()):     
    print(y_forced_response_concatenate[m].shape)
    print((y_forced_response_concatenate[m]/variance_processed_ssp585[m]).shape)
    res[idx_m] = torch.mean(y_forced_response_concatenate[m]/variance_processed_ssp585[m],axis=0)

(594,)
torch.Size([594])
(165,)
torch.Size([165])
(99,)
torch.Size([99])
(165,)
torch.Size([165])
(198,)
torch.Size([198])
(330,)
torch.Size([330])
(132,)
torch.Size([132])
(99,)
torch.Size([99])
(264,)
torch.Size([264])
(99,)
torch.Size([99])
(330,)
torch.Size([330])
(165,)
torch.Size([165])
(132,)
torch.Size([132])
(1650,)
torch.Size([1650])
(165,)
torch.Size([165])
(132,)
torch.Size([132])
(297,)
torch.Size([297])
(198,)
torch.Size([198])
(99,)
torch.Size([99])
(1320,)
torch.Size([1320])
(231,)
torch.Size([231])
(1650,)
torch.Size([1650])
(330,)
torch.Size([330])
(990,)
torch.Size([990])
(99,)
torch.Size([99])
(165,)
torch.Size([165])
