In [29]:
import os
import sys
maindir = os.getcwd()
sys.path.append(maindir+"/src")


import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt

from preprocessing import data_processing, compute_anomalies_and_scalers, \
                            compute_forced_response, \
                            numpy_to_torch, rescale_and_merge_training_and_test_sets, \
                            rescale_training_and_test_sets

# Load data 

In [30]:
############### Load climate model raw data for SST
with open('data/ssp585_time_series.pkl', 'rb') as f:
    data = pickle.load(f)

###################### Load longitude and latitude 
with open('data/lon.npy', 'rb') as f:
    lon = np.load(f)

with open('data/lat.npy', 'rb') as f:
    lat = np.load(f)

# define grid (+ croping for latitude > 60)
lat_grid, lon_grid = np.meshgrid(lat[lat<=60], lon, indexing='ij')

lat_size = lat_grid.shape[0]
lon_size = lon_grid.shape[1]
time_period=34 # 1981-2015

## Data preprocessing

In [31]:
# define pytorch precision
dtype = torch.float32

data_processed, notnan_idx, nan_idx = data_processing(data, lon, lat,max_models=100)
x, means, vars = compute_anomalies_and_scalers(data_processed, lon_size, lat_size, nan_idx, time_period=34)
y = compute_forced_response(data_processed, lon_size, lat_size, nan_idx, time_period=34)

x,y, means, vars = numpy_to_torch(x,y,means,vars, dtype=dtype)

  means[m] = np.nanmean(data_reshaped[m],axis=0)
  vars[m] = np.nanvar(data_reshaped[m],axis=0)
  mean_spatial_ensemble = np.nanmean(y_tmp,axis=0)


### Build training and test sets by removing a singe model m0

In [32]:
m0= 'ICON-ESM-LR'
training_models, x_rescaled, y_rescaled = rescale_training_and_test_sets(m0,x,y,means,vars,dtype=dtype)
training_models, x_train, y_train, x_test, y_test = rescale_and_merge_training_and_test_sets(m0,x,y,means,vars,dtype=dtype)

### import ML algorithms 

In [33]:
from algorithms import ridge_regression, ridge_regression_low_rank, low_rank_projection, \
                        prediction, compute_gradient, train_robust_weights_model, compute_weights

from leave_one_out import leave_one_out_single, leave_one_out_procedure

# We would like to solve the problem with trace norm regularizer
## $\min_{W} \sum_{m} \lVert Y^m - X^m W \rVert_F^2 + \lambda \lVert W \rVert_F^2 + \nu \lVert W \rVert_*$

In [None]:
# to solve this problem we need to compute the proximal operator of elastic net
# we will use the proximal gradient descent algorithm

# we will use the following algorithm:
# 1. initialize the weights
# 2. compute the gradient
# 3. update the weights
# 4. compute the proximal operator
# 5. repeat 2-4 until convergence

####################### HERE WE DEFINE THE PROXIMAL OPERATORS ######################

def ridge_regression(X, Y, lambda_=1.0,dtype=torch.float32,verbose=False):
    """
    Computes the closed-form solution for reduced rank regression.
    
    Args:
        X (torch.Tensor): Predictor matrix of shape (n, p).
        Y (torch.Tensor): Response matrix of shape (n, q).
        lambda_ (scalar): Ridge penalty coefficient.
        
    Returns:
        U (torch.Tensor): Low-rank predictor coefficients of shape (p, rank).
        V (torch.Tensor): Low-rank response coefficients of shape (q, rank).
    """

    # compute Penroe Morose pseudo inverse of X^T @ X
    P = torch.linalg.inv(X.T @ X + lambda_ * torch.eye(X.shape[1],dtype=dtype))
    
    # compute ordinary least square solution 
    W_ols = P @ X.T @ Y

    # print loss function 
    if verbose:
        loss = torch.norm(Y - X @ W_ols,p='fro')**2 + lambda_ * torch.norm(W_ols,p='fro')**2
        print("Loss function: ", loss.item())
    return W_ols


def singular_value_thresholding(D, nu_):
    """Singular Value Thresholding (SVT) operator: D -> U * S_nu * V^T"""
    U, S, V = torch.svd(D)
    S_nu = torch.clamp(S - nu_, min=0)  # Soft-thresholding on singular values
    return U @ torch.diag(S_nu) @ V.t()



def ridge_and_trace_norm_minimization(X,Y,lambda_,nu_):
    """Compute the proximal operator of the elastic net penalty.
       argmin_(W) 1/2 ||Y - XW||_F^2 + lambda_ ||W||_F^2 + nu_ ||W||_* = 
       SingValue Soft-thresholding( (lambda I + X^T X )^{-1} X^T Y, nu_/lambda ))
    """
    W = ridge_regression(X,Y,lambda_,verbose=True)
    return singular_value_thresholding(W, nu_/lambda_)

In [52]:
# compute the proximal operator of the elastic net penalty
lambda_ = 100.0
nu_ = 100.0

W_ridge = torch.zeros(x_train.shape[1],y_train.shape[1],dtype=dtype)
W_ridge[np.ix_(notnan_idx,notnan_idx)] = ridge_regression(x_train[:,notnan_idx], y_train[:,notnan_idx], lambda_, verbose=False)

W_ridge_lr = torch.zeros(x_train.shape[1],y_train.shape[1],dtype=dtype)
W_ridge_lr[np.ix_(notnan_idx,notnan_idx)] = ridge_and_trace_norm_minimization(x_train[:,notnan_idx], y_train[:,notnan_idx], lambda_, nu_)

# compute the prediction on test climate model 
y_pred_ridge = torch.zeros_like(y_test)
y_pred_ridge[:,nan_idx] = float('nan')
y_pred_ridge[:,notnan_idx] = x_test[:,notnan_idx] @ W_ridge[np.ix_(notnan_idx,notnan_idx)]

# compute the prediction on test climate model with low rank approximation
y_pred_ridge_lr = torch.zeros_like(y_test)
y_pred_ridge_lr[:,nan_idx] = float('nan')
y_pred_ridge_lr[:,notnan_idx] = x_test[:,notnan_idx] @ W_ridge_lr[np.ix_(notnan_idx,notnan_idx)]


# compute rmse with respect to the true response
rmse_ridge = torch.sqrt(torch.nanmean((y_test-y_pred_ridge)**2))
rmse_ridge_lr = torch.sqrt(torch.nanmean((y_test-y_pred_ridge_lr)**2))

print("RMSE Ridge: ", rmse_ridge.item())
print("RMSE Ridge Low Rank: ", rmse_ridge_lr.item())

# check the rank of the matrices
print("Rank of W_ridge: ", torch.linalg.matrix_rank(W_ridge))
print("Rank of W_ridge_lr: ", torch.linalg.matrix_rank(W_ridge_lr))

Loss function:  112849.109375
RMSE Ridge:  3.4417343139648438
RMSE Ridge Low Rank:  1.157891869544983
Rank of W_ridge:  tensor(1298)
Rank of W_ridge_lr:  tensor(4)


## We would like to solve the problem with Trace norm regularization.
## $\min_{W}  \mu \log \left(\sum_{m} \exp(\frac{1}{\mu} \Vert Y^m - X^m W\Vert_F^2 ) \right) + \lambda \lVert W \rVert_*$
##
## Two options: 
### 1 - Solve the problem using variational formulation ($\eta$-trick)
### 2 - Solve the proble using accelerated gradient descent of Ji et al. 2009.

In [None]:
##### Use variational formulation ###############
def compute_gradient_trace_norm(models,x,y,w,B,notnan_idx,lambda_=1.0,mu_=1.0, gamma_=1.0):
    """This function computes the gradient of ridge log-sum-exp loss with respect to W + ridge regularization + trace norm rgularizer.

    Args:
        - x, y: input-output pair
        - w: regressor matrix
        - B: positive definite matrix used in the variation 
        
    Returns:
        - Gradient matrix: torch.tensor d x d
    """
    res = torch.zeros(len(models), w.shape[0], w.shape[0]).to(dtype)
    res_sumexp = torch.zeros(len(models)).to(dtype)

    print("Gradient start to be computed")
    for idx_m, m in enumerate(models):

        # compute -2X_{m,r}^T (Y_{m,r}^T - X_{m,r}^T W)
        res[idx_m][np.ix_(notnan_idx,notnan_idx)] = - 2*torch.mean(torch.bmm(torch.transpose(x[m][:,:,notnan_idx], 1,2) , \
                                                        y[m][:,:,notnan_idx] - x[m][:,:,notnan_idx] @ w[np.ix_(notnan_idx,notnan_idx)]),dim=0)

        # compute the exponential term
        res_sumexp[idx_m] = (1/mu_)*torch.mean(torch.norm(y[m][:,:,notnan_idx] - x[m][:,:,notnan_idx] @ w[np.ix_(notnan_idx,notnan_idx)],p='fro',dim=(1,2))**2)
            
    softmax = torch.nn.Softmax(dim=0)
    res_sumexp = softmax(res_sumexp)

    # compute gradient as sum (res * softmax)
    grad = torch.sum(torch.unsqueeze(torch.unsqueeze(res_sumexp,-1),-1) * res, dim=0)
    grad[np.ix_(notnan_idx,notnan_idx)] = grad[np.ix_(notnan_idx,notnan_idx)] + 2*lambda_* w[np.ix_(notnan_idx,notnan_idx)]
    
    return grad 

def train_robust_weights_trace_norm(models,x,y,notnan_idx,lambda_=1.0,mu_=1.0,gamma_=1.0,lr=0.1,nb_iterations=10):
    """This function computes the gradient of ridge log-sum-exp loss with respect to W.

       Args:
            
       Returns:
    """
    w = torch.zeros(lon_size*lat_size,lon_size*lat_size).to(dtype)
    B = torch.eye(w.shape[0]).to(dtype)
    w_old = torch.zeros(lon_size*lat_size,lon_size*lat_size).to(dtype)

    training_loss = torch.zeros(nb_iterations)
    
    # run a simple loop
    for it in range(nb_iterations):


        # accelerate gradient descent
        if it > 1:
            w_tmp = w + ((it-1)/(it+2)) * (w - w_old)
        else:
            w_tmp = w.detach()

        # save old parameter
        w_old = w.clone().detach()

        # compute gradient
        print(" Compute gradient ")
        
        grad = compute_gradient_trace_norm(models,x,y,w_tmp,B,notnan_idx,lambda_,mu_,gamma_)

        ######################### Update coordinates ###############
        # update the variable w
        w = w_tmp - lr * grad

        print(" Update intermediate variable ")
        # update variable B as square root (W W^T + gamma * I)^(1/2)
        B = sqrtm_evd(w @ w.T +  gamma_* torch.eye(w.shape[0],dtype=dtype))

        # compute loss functon to check convergence 
        res = torch.zeros(len(models))
        
        print(" Compute loss function ")
        # compute loss functon to check convergence 
        res = torch.zeros(len(models))

        for idx_m, m in enumerate(models):

            # compute residuals
            res[idx_m] = torch.mean(torch.norm(y[m][:,:,notnan_idx] -x[m][:,:,notnan_idx] @ w[notnan_idx,:][:,notnan_idx], p='fro',dim=(1,2))**2,dtype=dtype)
    
        obj = mu_*torch.logsumexp((1/mu_)* res,0)
        obj += lambda_*torch.norm(w,p='fro')**2

        # add trace norm regularization
        obj_tmp = 0.5*gamma_*torch.trace(w.T @ torch.linalg.pinv(B) @ w) 
        # obj_tmp += 0.5*gamma_* torch.trace(torch.linalg.pinv(B)) 
        # obj_tmp += 0.5*gamma_*0.1 *torch.trace(B) 
        

        print("Iteration ", it,  ": Loss function : ", (obj+obj_tmp).item())
        print("Rank of w: ", torch.linalg.matrix_rank(w))
        print("Nuclear norm of w: ", torch.norm(w, p='nuc'))
        print("Variational forumlation of the trace norm: ", obj_tmp.item())
        
        training_loss[it] = (obj+obj_tmp).item()

    plt.close('all')
    plt.figure()
    plt.plot(range(nb_iterations),training_loss)
    plt.title('Training loss')
    plt.ylabel('Loss')
    plt.xlabel('Iterations')
    plt.show()
    
    return w, training_loss


In [12]:
lambda_tmp = 100.0
mu_tmp = 1000.0

w_robust, training_loss = train_robust_weights_trace_norm(training_models,x_rescaled,y_rescaled,notnan_idx,lambda_=lambda_tmp,mu_=mu_tmp,gamma_=0.5,lr=1e-5,nb_iterations=100)

 Compute gradient 
Gradient start to be computed
 Update intermediate variable 
 Compute loss function 
Iteration  0 : Loss function :  97474.59375
Rank of w:  tensor(136)
Nuclear norm of w:  tensor(2.9862)
Variational forumlation of the trace norm:  0.2411424070596695
 Compute gradient 
Gradient start to be computed
 Update intermediate variable 
 Compute loss function 
Iteration  1 : Loss function :  88058.046875
Rank of w:  tensor(272)
Nuclear norm of w:  tensor(5.0253)
Variational forumlation of the trace norm:  0.33060672879219055
 Compute gradient 
Gradient start to be computed
 Update intermediate variable 
 Compute loss function 
Iteration  2 : Loss function :  81736.0546875
Rank of w:  tensor(274)
Nuclear norm of w:  tensor(7.3241)
Variational forumlation of the trace norm:  0.4835149943828583
 Compute gradient 
Gradient start to be computed
 Update intermediate variable 
 Compute loss function 
Iteration  3 : Loss function :  72835.28125
Rank of w:  tensor(400)
Nuclear norm o

KeyboardInterrupt: 

In [None]:
##### code of paper of Ji et al. 2009 about accelerated gradient descent.

def singular_value_thresholding(D, tau):
    """Singular Value Thresholding (SVT) operator: D -> U * S_tau * V^T"""
    U, S, V = torch.svd(D)
    S_tau = torch.clamp(S - tau, min=0)  # Soft-thresholding on singular values
    return U @ torch.diag(S_tau) @ V.t()

def accelerated_trace_norm_minimization(models,x,y, lambda_=1.0,mu_=1.0,tau=1.0, max_iter=500, tol=1e-6):
    """
    Implements the Accelerated Gradient Method for Trace Norm Minimization
    from Ji & Ye (2009).
    
    Solves:
        min_W (1/2) ||W - M||_F^2 + tau ||W||_*

    Parameters:
    - M: Input matrix (torch.Tensor)
    - tau: Regularization parameter (controls nuclear norm weight)
    - max_iter: Maximum number of iterations
    - tol: Convergence tolerance

    Returns:
    - W_k: Optimized low-rank matrix
    """
    # Initialize variables
    Z_k = torch.zeros(lon_size*lat_size,lon_size*lat_size).to(torch.float64)  # Z_k (momentum variable)
    W_k = torch.zeros(lon_size*lat_size,lon_size*lat_size).to(torch.float64) # W_k (solution)
    t_k = 1  # Momentum parameter

    for k in range(max_iter):
        W_k_prev = W_k.clone()  # Store previous iterate W_{k-1}

        # Gradient step: Compute G_k = Z_k - M
        # G_k = Z_k - M  # Gradient of (1/2) ||W - M||_F^2
        print("Compute gradient ")
        G_k = compute_gradient(models,x,y,W_k,notnan_idx,lambda_,mu_)


        print("Compute proximal step ")
        # Proximal step: Apply singular value thresholding (SVT)
        W_k = singular_value_thresholding(Z_k - G_k, tau)

        print("Momentum update ")
        # Momentum update
        t_k_next = 0.5 * (1 + torch.sqrt(1 + 4 * torch.tensor(t_k) ** 2))  # Compute t_{k+1}
        Z_k = W_k + ((t_k - 1) / t_k_next) * (W_k - W_k_prev)  # Update Z_k
        t_k = t_k_next  # Update t_k

        # print loss function
        # 0.5 * ||X - A||_F^2 + lambda * ||X||_*
        # loss =  0.5* torch.norm(W_k - M,p='fro')**2 + tau*torch.norm(W_k,p='nuc')

        # compute loss functon to check convergence 
        res = torch.zeros(len(models))
        
        for idx_m, m in enumerate(models):  
            
             # compute residuals
            res[idx_m] = torch.mean(torch.norm(y[m][:,:,notnan_idx] -x[m][:,:,notnan_idx] @ W_k[notnan_idx,:][:,notnan_idx], p='fro',dim=(1,2))**2)
                
            
        obj = mu_*torch.logsumexp((1/mu_)* res,0)
        obj += lambda_*torch.norm(W_k,p='fro')**2
        obj += tau_*torch.norm(W_k,p='nuc')
        loss = obj
        
        print("Iteration ", k, ": ", loss.item())

        # Convergence check
        if torch.norm(W_k - W_k_prev, p="fro") < tol:
            break

    return W_k


tau_ = 10000  # Regularization parameter
lambda_ = 100.0
mu_ = 1000.0
W_opt = accelerated_trace_norm_minimization(training_models,x_stacked,y_stacked, lambda_,mu_, tau_, max_iter=500, tol=1e-7)

Compute gradient 
Compute proximal step 
Momentum update 
Iteration  0 :  39652467671040.0
Compute gradient 
Compute proximal step 
Momentum update 


  t_k_next = 0.5 * (1 + torch.sqrt(1 + 4 * torch.tensor(t_k) ** 2))  # Compute t_{k+1}


Iteration  1 :  1.0877854027642322e+23
Compute gradient 
Compute proximal step 
Momentum update 
Iteration  2 :  2.9890104802734357e+32
Compute gradient 
Compute proximal step 


KeyboardInterrupt: 

In [None]:
w_robust, training_loss = train_robust_weights_model(training_models,x,y,lon_size,lat_size,notnan_idx,rank=None,lambda_=1.0,mu_=1.0,lr=0.000001,nb_iterations=2)

In [None]:
w_robust, y_pred, y_test, rmse_train = leave_one_out_single(m0,x,y,vars,\
                                                              lon_size,lat_size,notnan_idx,nan_idx,\
                                                              lr=0.000001,nb_gradient_iterations=2,time_period=33,\
                                                              rank=5,lambda_=100.0,method='robust',mu_=1000.0,verbose=True)

## try to optimize with nuclear norm

In [None]:
def train_robust_model_autograd(x,y,lon_size,lat_size,models,lambda_=1.0,mu_=1.0,nbEpochs=100,verbose=True):
    """
    Learn parameter β such that β = argmin( log Σ_m exp(||y_m - X_m^T β||^2) ).

    Args:
        - x,y : location, observation 
        - lon_size, lat_size: longitude and latitude grid size (Int)
        - models: (sub)list of models (list)
        - mu_: softmax coefficient (float)
        - nbepochs: number of optimization steps (Int)
        - verbose: display logs (bool)
    """

    # define variable beta
    w = torch.zeros(lon_size*lat_size,lon_size*lat_size).to(torch.float64)
    w.requires_grad_(True) 

    # mat_eta = torch.eye(w.shape[0],w.shape[0]).to(torch.float64)
    # mat_eta.requires_grad_(True) 

    # define optimizer
    optimizer = torch.optim.Adam([w],lr=1e-5)

    # stopping criterion
    criteria = torch.tensor(0.0)
    criteria_tmp = torch.tensor(1.0) 
    epoch = 0
    training_loss = torch.zeros(nbEpochs)
    
            
    # --- optimization loop ---                
    while (epoch < nbEpochs):

        # update criteria
        criteria_tmp = criteria.clone()
                      
        optimizer.zero_grad()
        ############### Define loss function ##############
        res = torch.zeros(len(models))

        for idx_m, m in enumerate(models):  
            for idx_r, r in enumerate(x[m].keys()):
                res[idx_m] += torch.sum((y[m][r][:,notnan_idx] -x[m][r][:,notnan_idx] @ w[notnan_idx,:][:,notnan_idx] )**2)
                
            res[idx_m] = res[idx_m]/len(x[m].keys())
            
        # obj = mu_*torch.logsumexp((1/mu_)* res,0)

        # Compute the nuclear norm (sum of singular values)
        U, S, V = torch.svd(w)  # Singular Value Decomposition

        # check if it works with simple linear regression
        obj = torch.sum(res)
        # obj += 0.5*lambda_*( torch.trace(w @ torch.linalg.inv(mat_eta) @ w.T) + torch.trace(mat_eta))
        obj +=  lambda_* S.sum()
        # obj += lambda_*torch.norm(w,p='nuc')

        
        #define loss function
        loss = obj

        # set the training loss
        training_loss[epoch] = loss.detach().item()
                    
        # Use autograd to compute the backward pass. 
        loss.backward()               
        
        # take a step into optimal direction of parameters minimizing loss
        optimizer.step() 

        # print rank of matrix W
        print("Rank of the matrix : ", torch.linalg.matrix_rank(w))

        if(verbose==True):
            if(epoch % 1 == 0):
                print('Epoch ', epoch, 
                        ', loss=', training_loss[epoch].detach().item()
                        )
        criteria = loss
        epoch +=1
    
    plt.figure()
    plt.plot(range(nbEpochs),training_loss)
    plt.title('Training loss')
    plt.ylabel('Loss')
    plt.xlabel('iterations')
    plt.show()
    
    return w

In [None]:
w = train_robust_model_autograd(x,y,lon_size,lat_size,training_models,lambda_=10000.0,mu_=10.0,nbEpochs=100,verbose=True)