# Set up

In [21]:
import os
import random
import warnings

import numpy as np
import xarray as xr

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from mlp import *
from unet import *
from lr import *
from utils import *
from bos import *

In [2]:
# Ensure Experiment Reproducibility
RANDOM_SEED = 42

# Seed every random number generator the notebook relies on:
# Python's `random`, NumPy's legacy global RNG and PyTorch.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

if torch.cuda.is_available():
    # Seed the generators of all visible GPUs (this covers the current device,
    # so a separate torch.cuda.manual_seed call is not needed).
    torch.cuda.manual_seed_all(RANDOM_SEED)

    # Additional settings for CUDA operations
    torch.backends.cudnn.deterministic = True  # Restrict cuDNN to deterministic algorithms
    # Disable the cuDNN autotuner, which benchmarks several convolution
    # algorithms and picks the fastest one; that choice can differ between runs.
    torch.backends.cudnn.benchmark = False

# Data Preparation

In [3]:
PERSISTENT_BUCKET = 'gs://leap-persistent/jingwenlyuu/S2S'


def _load_gefs(lead, years):
    """Load the GEFS `prcp` forecast for one lead time as a NumPy array.

    Opens ``{PERSISTENT_BUCKET}/training_data/{lead}/GEFS_pra_{years}.zarr``,
    renames the Y/X coordinates to lat/lon, averages over the ``M`` dimension
    (presumably the ensemble-member axis - TODO confirm) and returns ``prcp``
    with axes ordered (S, lat, lon).
    """
    store = f'{PERSISTENT_BUCKET}/training_data/{lead}/GEFS_pra_{years}.zarr'
    dataset = xr.open_zarr(store).rename({'Y': 'lat', 'X': 'lon'})
    return dataset.mean(dim='M').prcp.transpose('S', 'lat', 'lon').values


def _load_imd(lead, years):
    """Load the IMD `prcp` field for one lead time as a NumPy array.

    Opens ``{PERSISTENT_BUCKET}/training_data/{lead}/IMD_pra_{years}.zarr``,
    renames the Y/X coordinates to lat/lon and returns ``prcp`` with axes
    ordered (T, lat, lon).
    """
    store = f'{PERSISTENT_BUCKET}/training_data/{lead}/IMD_pra_{years}.zarr'
    dataset = xr.open_zarr(store).rename({'Y': 'lat', 'X': 'lon'})
    return dataset.prcp.transpose('T', 'lat', 'lon').values


# Week 1 data
week1_gefs_1989_2010 = _load_gefs('week1', '1989-2010')
week1_gefs_2011_2018 = _load_gefs('week1', '2011-2018')
week1_imd_1989_2010 = _load_imd('week1', '1989-2010')
week1_imd_2011_2018 = _load_imd('week1', '2011-2018')

# Week 2 data
week2_gefs_1989_2010 = _load_gefs('week2', '1989-2010')
week2_gefs_2011_2018 = _load_gefs('week2', '2011-2018')
week2_imd_1989_2010 = _load_imd('week2', '1989-2010')
week2_imd_2011_2018 = _load_imd('week2', '2011-2018')

# Week 3-4 data
week34_gefs_1989_2010 = _load_gefs('week34', '1989-2010')
week34_gefs_2011_2018 = _load_gefs('week34', '2011-2018')
week34_imd_1989_2010 = _load_imd('week34', '1989-2010')
week34_imd_2011_2018 = _load_imd('week34', '2011-2018')

In [4]:
# Constants
N_VAL = 70  # Half of 2011-2018 data (per the original note; not verified in this cell)


def _split_lead(gefs_hist, imd_hist, gefs_recent, imd_recent, n_val=N_VAL):
    """Build the train/val/test arrays for one lead time.

    Parameters
    ----------
    gefs_hist, imd_hist : arrays for 1989-2010, used in full as training data.
    gefs_recent, imd_recent : arrays for 2011-2018; the first ``n_val`` samples
        along axis 0 form the validation set, the remainder the test set.
    n_val : number of leading 2011-2018 samples assigned to validation.

    Returns
    -------
    tuple, in this order:
        (train_input, train_output, val_input, val_output,
         test_input, test_output,
         train_val_input, train_val_output,
         val_test_input, val_test_output)

    Raises
    ------
    ValueError
        If a forecast array and its matching observation array differ in
        length along axis 0, since every later cell pairs them sample by sample.
    """
    if len(gefs_hist) != len(imd_hist) or len(gefs_recent) != len(imd_recent):
        raise ValueError('GEFS and IMD arrays must have the same number of samples')

    train_in, train_out = gefs_hist, imd_hist
    val_in, val_out = gefs_recent[:n_val], imd_recent[:n_val]
    test_in, test_out = gefs_recent[n_val:], imd_recent[n_val:]

    # Combined sets used by models that do not need a held-out validation split.
    train_val_in = np.concatenate([train_in, val_in], axis=0)
    train_val_out = np.concatenate([train_out, val_out], axis=0)
    val_test_in = np.concatenate([val_in, test_in], axis=0)
    val_test_out = np.concatenate([val_out, test_out], axis=0)

    return (train_in, train_out, val_in, val_out, test_in, test_out,
            train_val_in, train_val_out, val_test_in, val_test_out)


# Week 1 splits
(week1_train_input, week1_train_output,
 week1_val_input, week1_val_output,
 week1_test_input, week1_test_output,
 week1_train_val_input, week1_train_val_output,
 week1_val_test_input, week1_val_test_output) = _split_lead(
    week1_gefs_1989_2010, week1_imd_1989_2010,
    week1_gefs_2011_2018, week1_imd_2011_2018)

# Week 2 splits
(week2_train_input, week2_train_output,
 week2_val_input, week2_val_output,
 week2_test_input, week2_test_output,
 week2_train_val_input, week2_train_val_output,
 week2_val_test_input, week2_val_test_output) = _split_lead(
    week2_gefs_1989_2010, week2_imd_1989_2010,
    week2_gefs_2011_2018, week2_imd_2011_2018)

# Week 3-4 splits
(week34_train_input, week34_train_output,
 week34_val_input, week34_val_output,
 week34_test_input, week34_test_output,
 week34_train_val_input, week34_train_val_output,
 week34_val_test_input, week34_val_test_output) = _split_lead(
    week34_gefs_1989_2010, week34_imd_1989_2010,
    week34_gefs_2011_2018, week34_imd_2011_2018)

# Raw Forecasts (Without Bias Correction)

In [5]:
# Mean squared error of the uncorrected forecasts, averaged over axis 0
# (the sample axis), leaving one value per grid cell.
mse_raw_week1, mse_raw_week2, mse_raw_week34 = (
    ((forecast - observed) ** 2).mean(axis=0)
    for forecast, observed in (
        (week1_test_input, week1_test_output),
        (week2_test_input, week2_test_output),
        (week34_test_input, week34_test_output),
    )
)

In [6]:
# Correlation between raw forecast and observation on the test split,
# computed once per lead time with the shared calc_spatial_correlation helper.
corr_raw_week1, corr_raw_week2, corr_raw_week34 = (
    calc_spatial_correlation(forecast, observed)
    for forecast, observed in (
        (week1_test_input, week1_test_output),
        (week2_test_input, week2_test_output),
        (week34_test_input, week34_test_output),
    )
)

# LR

In [7]:
# Linear-regression baseline: fit on train+val, predict the test split.
pre_lr_week1, pre_lr_week2, pre_lr_week34 = (
    perform_linear_regression(fit_input, fit_output, test_input, test_output)
    for fit_input, fit_output, test_input, test_output in (
        (week1_train_val_input, week1_train_val_output, week1_test_input, week1_test_output),
        (week2_train_val_input, week2_train_val_output, week2_test_input, week2_test_output),
        (week34_train_val_input, week34_train_val_output, week34_test_input, week34_test_output),
    )
)

# Per-grid-cell MSE over the test samples.
mse_lr_week1, mse_lr_week2, mse_lr_week34 = (
    ((predicted - observed) ** 2).mean(axis=0)
    for predicted, observed in (
        (pre_lr_week1, week1_test_output),
        (pre_lr_week2, week2_test_output),
        (pre_lr_week34, week34_test_output),
    )
)

# Correlation between prediction and observation.
cc_lr_week1, cc_lr_week2, cc_lr_week34 = (
    calc_spatial_correlation(predicted, observed)
    for predicted, observed in (
        (pre_lr_week1, week1_test_output),
        (pre_lr_week2, week2_test_output),
        (pre_lr_week34, week34_test_output),
    )
)

# Skill score relative to the raw forecast: 1 - MSE_model / MSE_raw.
ss_lr_week1, ss_lr_week2, ss_lr_week34 = (
    1 - mse_model / mse_reference
    for mse_model, mse_reference in (
        (mse_lr_week1, mse_raw_week1),
        (mse_lr_week2, mse_raw_week2),
        (mse_lr_week34, mse_raw_week34),
    )
)

# CNN

## Data

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 32


def _to_cnn_tensor(field):
    """Convert a NumPy array to a float tensor on `device`, inserting a new axis at position 1."""
    return torch.from_numpy(field).float().unsqueeze(1).to(device)


def _cnn_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook-wide `batch_size`."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# week 1
x_train_cnn_week1 = _to_cnn_tensor(week1_train_input)
y_train_cnn_week1 = _to_cnn_tensor(week1_train_output)
x_val_cnn_week1 = _to_cnn_tensor(week1_val_input)
y_val_cnn_week1 = _to_cnn_tensor(week1_val_output)
x_test_cnn_week1 = _to_cnn_tensor(week1_test_input)
y_test_cnn_week1 = _to_cnn_tensor(week1_test_output)

train_dataset_cnn_week1 = TensorDataset(x_train_cnn_week1, y_train_cnn_week1)
val_dataset_cnn_week1 = TensorDataset(x_val_cnn_week1, y_val_cnn_week1)
test_dataset_cnn_week1 = TensorDataset(x_test_cnn_week1, y_test_cnn_week1)

# Only the training loader shuffles; validation and test keep sample order.
train_loader_cnn_week1 = _cnn_loader(train_dataset_cnn_week1, shuffle=True)
val_loader_cnn_week1 = _cnn_loader(val_dataset_cnn_week1, shuffle=False)
test_loader_cnn_week1 = _cnn_loader(test_dataset_cnn_week1, shuffle=False)


# week 2
x_train_cnn_week2 = _to_cnn_tensor(week2_train_input)
y_train_cnn_week2 = _to_cnn_tensor(week2_train_output)
x_val_cnn_week2 = _to_cnn_tensor(week2_val_input)
y_val_cnn_week2 = _to_cnn_tensor(week2_val_output)
x_test_cnn_week2 = _to_cnn_tensor(week2_test_input)
y_test_cnn_week2 = _to_cnn_tensor(week2_test_output)

train_dataset_cnn_week2 = TensorDataset(x_train_cnn_week2, y_train_cnn_week2)
val_dataset_cnn_week2 = TensorDataset(x_val_cnn_week2, y_val_cnn_week2)
test_dataset_cnn_week2 = TensorDataset(x_test_cnn_week2, y_test_cnn_week2)

train_loader_cnn_week2 = _cnn_loader(train_dataset_cnn_week2, shuffle=True)
val_loader_cnn_week2 = _cnn_loader(val_dataset_cnn_week2, shuffle=False)
test_loader_cnn_week2 = _cnn_loader(test_dataset_cnn_week2, shuffle=False)


# week 34
x_train_cnn_week34 = _to_cnn_tensor(week34_train_input)
y_train_cnn_week34 = _to_cnn_tensor(week34_train_output)
x_val_cnn_week34 = _to_cnn_tensor(week34_val_input)
y_val_cnn_week34 = _to_cnn_tensor(week34_val_output)
x_test_cnn_week34 = _to_cnn_tensor(week34_test_input)
y_test_cnn_week34 = _to_cnn_tensor(week34_test_output)

train_dataset_cnn_week34 = TensorDataset(x_train_cnn_week34, y_train_cnn_week34)
val_dataset_cnn_week34 = TensorDataset(x_val_cnn_week34, y_val_cnn_week34)
test_dataset_cnn_week34 = TensorDataset(x_test_cnn_week34, y_test_cnn_week34)

train_loader_cnn_week34 = _cnn_loader(train_dataset_cnn_week34, shuffle=True)
val_loader_cnn_week34 = _cnn_loader(val_dataset_cnn_week34, shuffle=False)
test_loader_cnn_week34 = _cnn_loader(test_dataset_cnn_week34, shuffle=False)

## Train

In [20]:
# CNN (U-Net) training, currently disabled: every line below is commented out.
# Each train_cnn call writes to the same checkpoint path that the evaluation
# cells later pass to evaluate_cnn.
# NOTE(review): one model_cnn / optimizer pair is created once and passed to
# all three train_cnn calls, so week 2 and week 3-4 would start from the
# previous lead time's weights and optimizer state unless train_cnn resets
# them internally - confirm this is intended before re-enabling.

# set_random_seed(42)

# criterion = NaNMSELoss()

# model_cnn = UNet().to(device)
# optimizer = torch.optim.Adam(model_cnn.parameters(), lr=0.001)

# cnn_time_1 = train_cnn(model_cnn, train_loader_cnn_week1, val_loader_cnn_week1, criterion, optimizer, device,
#             save_path ='/home/jovyan/S2S/Meta-NN/gefs_checkpoint/week1/cnn.pth', n_epochs=300, patience=10)

# cnn_time_2 = train_cnn(model_cnn, train_loader_cnn_week2, val_loader_cnn_week2, criterion, optimizer, device,
#             save_path ='/home/jovyan/S2S/Meta-NN/gefs_checkpoint/week2/cnn.pth', n_epochs=300, patience=10)

# cnn_time_3 = train_cnn(model_cnn, train_loader_cnn_week34, val_loader_cnn_week34, criterion, optimizer, device,
#             save_path ='/home/jovyan/S2S/Meta-NN/gefs_checkpoint/week34/cnn.pth', n_epochs=300, patience=10)

## Evaluation

In [11]:
# The training cell that used to create `model_cnn` is commented out, so build
# the network here; otherwise this cell fails with NameError on a fresh kernel.
# evaluate_cnn is given a checkpoint path and, per its log output
# ("Loaded model parameters from ..."), restores the trained weights itself,
# so a freshly initialised UNet is a sufficient starting point.
model_cnn = UNet().to(device)

# Checkpoint location per lead time; `{lead}` is one of week1 / week2 / week34.
_cnn_checkpoint = '/home/jovyan/S2S/Meta-NN/gefs_checkpoint/{lead}/cnn.pth'

# Week 1: test-set predictions, per-grid-cell MSE, correlation, and skill
# score relative to the raw forecast (1 - MSE_cnn / MSE_raw).
pre_cnn_week1 = evaluate_cnn(model_cnn, device, test_loader_cnn_week1, _cnn_checkpoint.format(lead='week1'))
mse_cnn_week1 = ((pre_cnn_week1 - week1_test_output)**2 ).mean(axis=0)
cc_cnn_week1 = calc_spatial_correlation(pre_cnn_week1, week1_test_output)
ss_cnn_week1 = 1 - mse_cnn_week1/mse_raw_week1

# Week 2
pre_cnn_week2 = evaluate_cnn(model_cnn, device, test_loader_cnn_week2, _cnn_checkpoint.format(lead='week2'))
mse_cnn_week2 = ((pre_cnn_week2 - week2_test_output)**2 ).mean(axis=0)
cc_cnn_week2 = calc_spatial_correlation(pre_cnn_week2, week2_test_output)
ss_cnn_week2 = 1 - mse_cnn_week2/mse_raw_week2

# Week 3-4
pre_cnn_week34 = evaluate_cnn(model_cnn, device, test_loader_cnn_week34, _cnn_checkpoint.format(lead='week34'))
mse_cnn_week34 = ((pre_cnn_week34 - week34_test_output)**2 ).mean(axis=0)
cc_cnn_week34 = calc_spatial_correlation(pre_cnn_week34, week34_test_output)
ss_cnn_week34 = 1 - mse_cnn_week34/mse_raw_week34

  checkpoint = torch.load(checkpoint_path, map_location=device)


Loaded model parameters from /home/jovyan/S2S/Meta-NN/gefs_checkpoint/week1/cnn.pth
Loaded model parameters from /home/jovyan/S2S/Meta-NN/gefs_checkpoint/week2/cnn.pth
Loaded model parameters from /home/jovyan/S2S/Meta-NN/gefs_checkpoint/week34/cnn.pth


# ANN

## Data

In [13]:
def _as_float32(field):
    """Copy a NumPy array into a new float32 torch tensor."""
    return torch.tensor(field, dtype=torch.float32)


# week 1
X_train_ann_week1, y_train_ann_week1 = _as_float32(week1_train_input), _as_float32(week1_train_output)
X_val_ann_week1, y_val_ann_week1 = _as_float32(week1_val_input), _as_float32(week1_val_output)
X_test_ann_week1, y_test_ann_week1 = _as_float32(week1_test_input), _as_float32(week1_test_output)

# week 2
X_train_ann_week2, y_train_ann_week2 = _as_float32(week2_train_input), _as_float32(week2_train_output)
X_val_ann_week2, y_val_ann_week2 = _as_float32(week2_val_input), _as_float32(week2_val_output)
X_test_ann_week2, y_test_ann_week2 = _as_float32(week2_test_input), _as_float32(week2_test_output)

# week 34
X_train_ann_week34, y_train_ann_week34 = _as_float32(week34_train_input), _as_float32(week34_train_output)
X_val_ann_week34, y_val_ann_week34 = _as_float32(week34_val_input), _as_float32(week34_val_output)
X_test_ann_week34, y_test_ann_week34 = _as_float32(week34_test_input), _as_float32(week34_test_output)

## Train

In [19]:
# Per-grid-cell ANN training, currently disabled: every line below is
# commented out. One ANN is trained for each (i, j) location on that
# location's time series, skipping locations whose train/val series contain
# NaN, and saved to its own checkpoint file.
# NOTE(review): only the week 1 tensors and the week1 checkpoint directory
# appear here; the week2 / week34 checkpoints loaded by the evaluation cells
# were presumably produced by re-running this cell with the names swapped -
# confirm.

# time_steps, num_rows, num_cols = X_train_ann_week1.shape

# set_random_seed(42)

# for i in range(num_rows):
#     for j in range(num_cols):
        
#         X_train_loc = X_train_ann_week1[:, i, j].reshape(-1, 1)
#         y_train_loc = y_train_ann_week1[:, i, j].reshape(-1, 1)
#         X_val_loc = X_val_ann_week1[:, i, j].reshape(-1, 1)
#         y_val_loc = y_val_ann_week1[:, i, j].reshape(-1, 1)

#         if torch.isnan(X_train_loc).any() or torch.isnan(y_train_loc).any() or \
#            torch.isnan(X_val_loc).any() or torch.isnan(y_val_loc).any():
#             # print(f"Skipping location ({i}, {j}) due to NaN values")
#             continue
            
#         train_dataset = TensorDataset(X_train_loc, y_train_loc)
#         val_dataset = TensorDataset(X_val_loc, y_val_loc)
#         train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
#         val_loader = DataLoader(val_dataset, batch_size=32)

#         model = ANN()

#         criterion = nn.MSELoss()
#         optimizer = optim.Adam(model.parameters(), lr=0.01)

#         checkpoint_path = f'/home/jovyan/S2S/Meta-NN/gefs_checkpoint/week1/ann/checkpoint_loc_{i}_{j}.pth'

#         ann_time = train_ann(model, train_loader, val_loader, criterion, optimizer, num_epochs=300, checkpoint_path=checkpoint_path, patience=10, scheduler_patience=5)

# print("Training completed for all locations.")

## Evaluation

In [18]:
def _ann_grid_predictions(inputs, targets,
                          checkpoint_root='/home/jovyan/S2S/Meta-NN/gefs_checkpoint',
                          batch_size=32):
    """Run the per-grid-cell ANNs for every lead time and collect their predictions.

    Parameters
    ----------
    inputs, targets : dict mapping a lead-time name ('week1', 'week2',
        'week34') to a 3-D tensor indexed as [sample, i, j]. The lead-time
        name is also the checkpoint sub-directory.
    checkpoint_root : directory holding ``{lead}/ann/checkpoint_loc_{i}_{j}.pth``.
    batch_size : batch size of the per-location DataLoader.

    Returns
    -------
    dict mapping each lead-time name to a NumPy array shaped like
    ``inputs[lead]``; locations that were skipped stay NaN.

    A location is skipped for ALL lead times when the first lead time's target
    series contains NaN, or when any lead time's checkpoint file is missing.
    """
    leads = tuple(inputs)
    gate_lead = leads[0]
    predictions = {lead: np.full(inputs[lead].shape, np.nan) for lead in leads}
    _, num_rows, num_cols = inputs[gate_lead].shape

    for i in range(num_rows):
        for j in range(num_cols):
            # Only the first lead time's target is inspected for NaN.
            if torch.isnan(targets[gate_lead][:, i, j]).any():
                continue

            checkpoint_paths = {
                lead: f'{checkpoint_root}/{lead}/ann/checkpoint_loc_{i}_{j}.pth'
                for lead in leads
            }
            # All-or-nothing: one missing checkpoint skips the whole location.
            if not all(os.path.exists(path) for path in checkpoint_paths.values()):
                continue

            for lead in leads:
                model = ANN()
                # The model is never moved off the CPU, so map the checkpoint
                # to CPU too (lets GPU-saved checkpoints load on CPU-only
                # machines). weights_only=True restricts unpickling to
                # tensors/containers, which is all a state_dict holds.
                state_dict = torch.load(checkpoint_paths[lead], map_location='cpu', weights_only=True)
                model.load_state_dict(state_dict)

                location_dataset = TensorDataset(
                    inputs[lead][:, i, j].reshape(-1, 1),
                    targets[lead][:, i, j].reshape(-1, 1),
                )
                location_loader = DataLoader(location_dataset, batch_size=batch_size)
                predictions[lead][:, i, j] = evaluate_ann(model, location_loader).flatten()

    return predictions


_ann_test_predictions = _ann_grid_predictions(
    {'week1': X_test_ann_week1, 'week2': X_test_ann_week2, 'week34': X_test_ann_week34},
    {'week1': y_test_ann_week1, 'week2': y_test_ann_week2, 'week34': y_test_ann_week34},
)

pre_ann_week1 = _ann_test_predictions['week1']
pre_ann_week2 = _ann_test_predictions['week2']
pre_ann_week34 = _ann_test_predictions['week34']

In [16]:
# Per-grid-cell MSE of the ANN predictions over the test samples.
mse_ann_week1, mse_ann_week2, mse_ann_week34 = (
    ((predicted - observed) ** 2).mean(axis=0)
    for predicted, observed in (
        (pre_ann_week1, week1_test_output),
        (pre_ann_week2, week2_test_output),
        (pre_ann_week34, week34_test_output),
    )
)

# Correlation between ANN prediction and observation.
cc_ann_week1, cc_ann_week2, cc_ann_week34 = (
    calc_spatial_correlation(predicted, observed)
    for predicted, observed in (
        (pre_ann_week1, week1_test_output),
        (pre_ann_week2, week2_test_output),
        (pre_ann_week34, week34_test_output),
    )
)

# Skill score relative to the raw forecast: 1 - MSE_ann / MSE_raw.
ss_ann_week1, ss_ann_week2, ss_ann_week34 = (
    1 - mse_model / mse_reference
    for mse_model, mse_reference in (
        (mse_ann_week1, mse_raw_week1),
        (mse_ann_week2, mse_raw_week2),
        (mse_ann_week34, mse_raw_week34),
    )
)

# Weighted Average

In [15]:
# Validation-period ("meta") predictions of the three base models. These are
# the inputs to the ensemble-weight optimisation, so each model is fitted on
# the training split only and evaluated on the validation split.

# Linear regression: fit on train, predict val.
pre_lr_meta_week1 = perform_linear_regression(week1_train_input, week1_train_output, week1_val_input, week1_val_output)
pre_lr_meta_week2 = perform_linear_regression(week2_train_input, week2_train_output, week2_val_input, week2_val_output)
pre_lr_meta_week34 = perform_linear_regression(week34_train_input, week34_train_output, week34_val_input, week34_val_output)

# CNN: evaluate the saved checkpoints on the validation loaders.
pre_cnn_meta_week1 = evaluate_cnn(model_cnn, device, val_loader_cnn_week1, '/home/jovyan/S2S/Meta-NN/gefs_checkpoint/week1/cnn.pth')
pre_cnn_meta_week2 = evaluate_cnn(model_cnn, device, val_loader_cnn_week2, '/home/jovyan/S2S/Meta-NN/gefs_checkpoint/week2/cnn.pth')
pre_cnn_meta_week34 = evaluate_cnn(model_cnn, device, val_loader_cnn_week34, '/home/jovyan/S2S/Meta-NN/gefs_checkpoint/week34/cnn.pth')

# ANN: one model per grid cell and lead time. The dict keys double as the
# checkpoint sub-directory names.
_ann_val_inputs = {'week1': X_val_ann_week1, 'week2': X_val_ann_week2, 'week34': X_val_ann_week34}
_ann_val_targets = {'week1': y_val_ann_week1, 'week2': y_val_ann_week2, 'week34': y_val_ann_week34}
_ann_val_predictions = {
    lead: np.full(lead_input.shape, np.nan) for lead, lead_input in _ann_val_inputs.items()
}

time_steps, num_rows, num_cols = X_val_ann_week1.shape

# Loop through each location
for i in range(num_rows):
    for j in range(num_cols):
        # Skip the location (for every lead time) when the week 1 target has NaN;
        # only week 1 is inspected.
        if torch.isnan(y_val_ann_week1[:, i, j]).any():
            continue

        checkpoint_paths = {
            lead: f'/home/jovyan/S2S/Meta-NN/gefs_checkpoint/{lead}/ann/checkpoint_loc_{i}_{j}.pth'
            for lead in _ann_val_inputs
        }
        # All-or-nothing: a missing checkpoint for any lead time leaves the
        # location NaN in all three prediction arrays.
        if not all(os.path.exists(path) for path in checkpoint_paths.values()):
            continue

        for lead, checkpoint_path in checkpoint_paths.items():
            model = ANN()
            # The ANN stays on the CPU, so map the checkpoint there as well
            # (allows GPU-saved checkpoints to load on CPU-only machines).
            # weights_only=True limits unpickling to tensors/containers.
            model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))

            # Extract data for this location as column vectors.
            val_dataset = TensorDataset(
                _ann_val_inputs[lead][:, i, j].reshape(-1, 1),
                _ann_val_targets[lead][:, i, j].reshape(-1, 1),
            )
            val_loader = DataLoader(val_dataset, batch_size=32)

            _ann_val_predictions[lead][:, i, j] = evaluate_ann(model, val_loader).flatten()

pre_ann_meta_week1 = _ann_val_predictions['week1']
pre_ann_meta_week2 = _ann_val_predictions['week2']
pre_ann_meta_week34 = _ann_val_predictions['week34']

  
  checkpoint = torch.load(checkpoint_path, map_location=device)


Loaded model parameters from /home/jovyan/S2S/Meta-NN/gefs_checkpoint/week1/cnn.pth
Loaded model parameters from /home/jovyan/S2S/Meta-NN/gefs_checkpoint/week2/cnn.pth
Loaded model parameters from /home/jovyan/S2S/Meta-NN/gefs_checkpoint/week34/cnn.pth


  model_week1.load_state_dict(torch.load(checkpoint_path_week1))
  model_week2.load_state_dict(torch.load(checkpoint_path_week2))
  model_week34.load_state_dict(torch.load(checkpoint_path_week34))


In [17]:
# Validation-period predictions of each base model, keyed by model name.
preds_week1 = {
    'CNN': pre_cnn_meta_week1,
    'ANN': pre_ann_meta_week1,
    'LR': pre_lr_meta_week1
}

preds_week2 = {
    'CNN': pre_cnn_meta_week2,
    'ANN': pre_ann_meta_week2,
    'LR': pre_lr_meta_week2
}

preds_week34 = {
    'CNN': pre_cnn_meta_week34,
    'ANN': pre_ann_meta_week34,
    'LR': pre_lr_meta_week34
}

# Optimised weights per lead time ({period: {model_name: weight}}), kept so
# later cells can read them instead of copying numbers from the printout.
optimal_weights = {}

# NOTE(review): the same optimiser instance is reused for all three lead
# times; if BayesianEnsembleWeighting keeps state between optimize_weights
# calls, later periods could be affected - confirm.
ensemble_weighting = BayesianEnsembleWeighting(n_models=3)

# Silence warnings only while optimising, rather than for the rest of the
# session (requires `import warnings` in the import cell).
with warnings.catch_warnings():
    warnings.filterwarnings('ignore')

    for period, (y_true, preds) in [
        ("Week 1", (week1_val_output, preds_week1)),
        ("Week 2", (week2_val_output, preds_week2)),
        ("Week 3-4", (week34_val_output, preds_week34))
    ]:
        weights, opt_time = ensemble_weighting.optimize_weights(y_true, preds)
        optimal_weights[period] = weights
        print(f"\nOptimal weights for {period}:")
        for model_name, weight in weights.items():
            print(f"{model_name}: {weight:.3f}")
        print(f"Optimization time: {int(opt_time//3600)}h {int((opt_time%3600)//60)}m {int(opt_time%60)}s")


Optimal weights for Week 1:
CNN: 0.831
ANN: 0.169
LR: 0.001
Optimization time: 0h 0m 8s

Optimal weights for Week 2:
CNN: 0.388
ANN: 0.609
LR: 0.003
Optimization time: 0h 0m 8s

Optimal weights for Week 3-4:
CNN: 0.630
ANN: 0.068
LR: 0.302
Optimization time: 0h 0m 8s


In [19]:
# Blend weights applied to the test-period predictions of the base models.
# NOTE(review): these values were typed in by hand. They do not match the
# weights printed in the saved output of the optimisation cell (e.g. week 1
# CNN 0.818 here vs 0.831 printed), and the week 3-4 triple sums to 1.001.
# Values are kept as they were; re-derive them from the optimiser output
# before reporting results.
ENSEMBLE_WEIGHTS = {
    'week1': {'LR': 0.001, 'ANN': 0.181, 'CNN': 0.818},
    'week2': {'LR': 0.009, 'ANN': 0.635, 'CNN': 0.356},
    'week34': {'LR': 0.302, 'ANN': 0.067, 'CNN': 0.632},
}


def _blend(weights, pre_lr, pre_ann, pre_cnn):
    """Return the weighted sum of the LR, ANN and CNN predictions.

    `weights` maps 'LR', 'ANN' and 'CNN' to scalar weights. The result is NaN
    wherever any of the three prediction arrays is NaN.
    """
    return weights['LR'] * pre_lr + weights['ANN'] * pre_ann + weights['CNN'] * pre_cnn


pre_ensemble_week1 = _blend(ENSEMBLE_WEIGHTS['week1'], pre_lr_week1, pre_ann_week1, pre_cnn_week1)
pre_ensemble_week2 = _blend(ENSEMBLE_WEIGHTS['week2'], pre_lr_week2, pre_ann_week2, pre_cnn_week2)
pre_ensemble_week34 = _blend(ENSEMBLE_WEIGHTS['week34'], pre_lr_week34, pre_ann_week34, pre_cnn_week34)


# Per-grid-cell MSE, correlation with the observations, and skill score
# relative to the raw forecast (1 - MSE_ensemble / MSE_raw) for each lead time.
mse_ensemble_week1 = ((pre_ensemble_week1 - week1_test_output)**2 ).mean(axis=0)
cc_ensemble_week1 = calc_spatial_correlation(pre_ensemble_week1, week1_test_output)
ss_ensemble_week1 = 1 - mse_ensemble_week1/mse_raw_week1

mse_ensemble_week2 = ((pre_ensemble_week2 - week2_test_output)**2 ).mean(axis=0)
cc_ensemble_week2 = calc_spatial_correlation(pre_ensemble_week2, week2_test_output)
ss_ensemble_week2 = 1 - mse_ensemble_week2/mse_raw_week2

mse_ensemble_week34 = ((pre_ensemble_week34 - week34_test_output)**2 ).mean(axis=0)
cc_ensemble_week34 = calc_spatial_correlation(pre_ensemble_week34, week34_test_output)
ss_ensemble_week34 = 1 - mse_ensemble_week34/mse_raw_week34