## Reduced Rank Regression

In [1]:
import os, sys
import pickle as pkl
import numpy as np
import random
import warnings
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler

sys.path.append(os.path.join(os.getcwd(), 'utils'))

from utils.data_loading import *
from utils.data_processing import *
from utils.regression import *
from utils.animation import *
from utils.metrics import *
from utils.pipeline import *

# autoreload: automatically re-import modules (including the utils.* helpers imported above)
# before executing each cell, so edits to them are picked up without restarting the kernel
%reload_ext autoreload
%autoreload 2

# ignore warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
# Remove deprecation warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [2]:
# Input: the pickled time-series file (ssp585 scenario, per its name) stored in ./data
filename = 'ssp585_time_series.pkl'
data_path = os.path.join(os.getcwd(), 'data')

# Load, filter, NaN-mask and reshape the data.
# `nan_mask` is kept so the removed grid cells can be re-added later (the animation step uses it).
data, nan_mask = preprocess_data(data_path, filename)

Loading data from ssp585_time_series.pkl
Data loaded successfully.
Filtering data...


100%|██████████| 72/72 [00:00<00:00, 32968.33it/s]


Data filtered. Kept 34 models
Creating NaN mask...


100%|██████████| 34/34 [00:01<00:00, 18.32it/s]


NaN mask created.
Masking out NaN values...


100%|██████████| 34/34 [00:01<00:00, 19.10it/s]


NaN values masked out.
Reshaping data...


100%|██████████| 34/34 [00:03<00:00,  9.73it/s]


Data reshaped.
Adding the forced response to the data...


100%|██████████| 34/34 [00:19<00:00,  1.73it/s]


Forced response added.
Removing NaN values from the grid...


100%|██████████| 34/34 [00:03<00:00,  9.29it/s]


NaN values removed.


In [3]:
# Hyperparameter grids for the reduced-rank regression search.
# Every leave-one-out fold fits one model per (rank, lambda) pair, so the cost grows
# with len(FULL_RANKS) * len(FULL_LAMBDAS). A reduced grid is used by default to keep
# the notebook fast; set USE_FULL_GRID = True for the complete (slow) search.
FULL_LAMBDAS = [0.01, 0.1, 1, 10, 50, 100, 200]
FULL_RANKS = [1, 2, 5, 10, 50, 100]
QUICK_LAMBDAS = [1, 100]
QUICK_RANKS = [2, 10]

USE_FULL_GRID = False

# Copies, so later mutation of `lambdas`/`ranks` cannot alter the grid constants.
lambdas = list(FULL_LAMBDAS if USE_FULL_GRID else QUICK_LAMBDAS)
ranks = list(FULL_RANKS if USE_FULL_GRID else QUICK_RANKS)

In [4]:
# Only keep a few randomly chosen models so the pipeline runs in reasonable time
# (each leave-one-out fold took several minutes in the recorded run).
# Increase N_MODELS for a fuller evaluation.
N_MODELS = 3
SEED = 42

random.seed(SEED)  # global seed so the sampled subset is reproducible
# Cap at the number of available models: random.sample raises ValueError otherwise.
models = random.sample(list(data.keys()), min(N_MODELS, len(data)))
subset_data = {model: data[model] for model in models}
print(f"Models kept to test the pipeline: {models}")

Models kept to test the pipeline: ['GISS-E2-2-G', 'EC-Earth3', 'ACCESS-ESM1-5']


In [None]:
# Leave-one-model-out cross-validation over every (rank, lambda) pair in the grid.
# This is the expensive step (~10 minutes for 3 models in the recorded run).
# Outputs: per-model MSE distributions and MSEs grouped by hyperparameter combination.
center = True  # forwarded to loo_cross_validation (presumably centers the data before fitting — TODO confirm)
mse_distributions, mse_by_combination = loo_cross_validation(
    subset_data,
    lambdas,
    ranks,
    center=center,
)

  0%|          | 0/3 [00:00<?, ?it/s]

Normalizing data...


100%|██████████| 2/2 [00:00<00:00,  5.30it/s]
100%|██████████| 1/1 [00:00<00:00, 17.19it/s]


Data normalization completed.
Pooling data...


100%|██████████| 2/2 [00:00<00:00, 47662.55it/s]


Data pooled.
Performing leave-one-out cross-validation for model: GISS-E2-2-G
Fitting OLS...
RRR completed.
Fitting OLS...
RRR completed.
Fitting OLS...
RRR completed.
Fitting OLS...
RRR completed.


 33%|███▎      | 1/3 [03:43<07:27, 223.74s/it]

Normalizing data...


100%|██████████| 2/2 [00:00<00:00,  6.97it/s]
100%|██████████| 1/1 [00:00<00:00,  8.60it/s]


Data normalization completed.
Pooling data...


100%|██████████| 2/2 [00:00<00:00, 12052.60it/s]


Data pooled.
Performing leave-one-out cross-validation for model: EC-Earth3
Fitting OLS...
RRR completed.
Fitting OLS...
RRR completed.
Fitting OLS...
RRR completed.
Fitting OLS...
RRR completed.


 67%|██████▋   | 2/3 [07:02<03:28, 208.82s/it]

Normalizing data...


100%|██████████| 2/2 [00:00<00:00,  9.12it/s]
100%|██████████| 1/1 [00:00<00:00,  4.15it/s]


Data normalization completed.
Pooling data...


100%|██████████| 2/2 [00:00<00:00, 52428.80it/s]


Data pooled.
Performing leave-one-out cross-validation for model: ACCESS-ESM1-5
Fitting OLS...
RRR completed.
Fitting OLS...
RRR completed.
Fitting OLS...
RRR completed.
Fitting OLS...
RRR completed.


100%|██████████| 3/3 [09:44<00:00, 194.84s/it]


In [6]:
# KDE plot of the MSE distribution for every (rank, lambda) combination; the figure is saved under output/
plot_mse_distributions(
    mse_by_combination,
    ranks,
    lambdas,
    output_dir='output',
)

Saved MSE distribution KDE plot at output/mse_distributions_kde.png


In [7]:
# One MSE-distribution figure per held-out model (covering all rank/lambda pairs), each saved under output/
plot_mse_distributions_per_model(
    mse_distributions,
    models,
    ranks,
    lambdas,
    output_dir='output',
)

Saved MSE distribution plot for model GISS-E2-2-G at output/mse_distributions_GISS-E2-2-G.png
Saved MSE distribution plot for model EC-Earth3 at output/mse_distributions_EC-Earth3.png
Saved MSE distribution plot for model ACCESS-ESM1-5 at output/mse_distributions_ACCESS-ESM1-5.png


In [8]:
# Pick the (rank, lambda) pair with the best weighted score of MSE mean (weight 0.7)
# and MSE variance (weight 0.3); the choice is also written to output/best_hyperparameters.txt.
# NOTE(review): exact scoring rule lives in select_robust_hyperparameters — confirm lower is better.
best_rank_lambda, best_mse = select_robust_hyperparameters(
    mse_by_combination,
    mean_weight=0.7,
    variance_weight=0.3,
    output_dir='output',
)

Saved best hyperparameters at output/best_hyperparameters.txt


In [9]:
# Split the selected pair into its two hyperparameters and report them
(best_rank, best_lambda) = best_rank_lambda
summary = f"Selected best rank: {best_rank}, best lambda: {best_lambda}, with mean MSE: {best_mse:.4f}"
print(summary)

Selected best rank: 2, best lambda: 1, with mean MSE: 0.9269


In [10]:
# Re-run cross-validation (one fold per model) using only the selected rank and lambda.
# NOTE(review): unlike the loo_cross_validation call, no `center` argument is passed here —
# confirm final_cross_validation applies the same preprocessing.
final_mse_losses = final_cross_validation(
    subset_data,
    best_rank,
    best_lambda,
)

Final Cross-Validation:   0%|          | 0/3 [00:00<?, ?it/s]

Normalizing data...


100%|██████████| 2/2 [00:00<00:00,  5.52it/s]
100%|██████████| 1/1 [00:00<00:00, 16.91it/s]


Data normalization completed.
Pooling data...


100%|██████████| 2/2 [00:00<00:00, 17084.74it/s]


Data pooled.
Fitting OLS...
RRR completed.


Final Cross-Validation:  33%|███▎      | 1/3 [00:36<01:13, 36.94s/it]

Normalizing data...


100%|██████████| 2/2 [00:00<00:00,  7.26it/s]
100%|██████████| 1/1 [00:00<00:00,  8.67it/s]


Data normalization completed.
Pooling data...


100%|██████████| 2/2 [00:00<00:00, 64527.75it/s]


Data pooled.
Fitting OLS...
RRR completed.


Final Cross-Validation:  67%|██████▋   | 2/3 [01:06<00:32, 32.59s/it]

Normalizing data...


100%|██████████| 2/2 [00:00<00:00, 14.63it/s]
100%|██████████| 1/1 [00:00<00:00,  6.51it/s]


Data normalization completed.
Pooling data...


100%|██████████| 2/2 [00:00<00:00, 18196.55it/s]


Data pooled.
Fitting OLS...
RRR completed.


Final Cross-Validation: 100%|██████████| 3/3 [01:25<00:00, 28.66s/it]


In [11]:
# Distribution of the final cross-validation MSEs; the figure is saved under output/
plot_final_mse_distribution(
    final_mse_losses,
    output_dir='output',
)

Saved final MSE distribution plot at output/final_mse_distribution.png


In [12]:
# Chose a random model to test on
test_model = random.choice(models)

# Generate and save animations for the test model
generate_and_save_animations(
    data=subset_data,
    test_model=test_model,
    best_rank=best_rank,
    best_lambda=best_lambda,
    nan_mask=nan_mask,
    num_runs=3,
    output_dir="output",
    color_limits=(-2, 2)
)

Normalizing data...


100%|██████████| 2/2 [00:00<00:00,  8.11it/s]
100%|██████████| 1/1 [00:00<00:00, 22.75it/s]


Data normalization completed.
Pooling data...


100%|██████████| 2/2 [00:00<00:00, 32768.00it/s]


Data pooled.
Fitting OLS...
RRR completed.
Re-adding NaN values to the grid...
Re-adding NaN values to the grid...
Re-adding NaN values to the grid...
Re-adding NaN values to the grid...
Re-adding NaN values to the grid...
Re-adding NaN values to the grid...
Re-adding NaN values to the grid...
Re-adding NaN values to the grid...
Re-adding NaN values to the grid...
Animations saved in output/animations
