In [1]:
import os
import sys
maindir = os.getcwd()
sys.path.append(maindir+"/src")

In [2]:
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt

from preprocessing import data_processing, compute_anomalies_and_scalers, \
                            compute_forced_response, \
                            numpy_to_torch, rescale_and_merge_training_and_test_sets, \
                            rescale_training_and_test_sets


from plot_tools import plot_gt_vs_pred, animation_gt_vs_pred
from leave_one_out import leave_one_out_single, leave_one_out_procedure
from cross_validation import cross_validation_procedure

In [3]:
############### Load climate model raw data for SST
with open('data/ssp585_time_series.pkl', 'rb') as f:
    data = pickle.load(f)

###################### Load longitude and latitude 
with open('data/lon.npy', 'rb') as f:
    lon = np.load(f)

with open('data/lat.npy', 'rb') as f:
    lat = np.load(f)

# define grid (+ croping for latitude > 60)
lat_grid, lon_grid = np.meshgrid(lat[lat<=60], lon, indexing='ij')

lat_size = lat_grid.shape[0]
lon_size = lon_grid.shape[1]

In [4]:
# define pytorch precision
dtype = torch.float32

data_processed, notnan_idx, nan_idx = data_processing(data, lon, lat,max_models=100)
x, means, vars = compute_anomalies_and_scalers(data_processed, lon_size, lat_size, nan_idx, time_period=34)
y = compute_forced_response(data_processed, lon_size, lat_size, nan_idx, time_period=34)

x,y, means, vars = numpy_to_torch(x,y,means,vars, dtype=dtype)

  means[m] = np.nanmean(data_reshaped[m],axis=0)
  vars[m] = np.nanvar(data_reshaped[m],axis=0)
  mean_spatial_ensemble = np.nanmean(y_tmp,axis=0)


In [44]:
m0 = 'ICON-ESM-LR'

training_models, x_train, y_train, x_test, y_test = rescale_and_merge_training_and_test_sets(m0,x,y,means,vars,dtype=dtype)
training_models, x_rescaled, y_rescaled = rescale_training_and_test_sets(m0,x,y,means,vars,dtype=dtype)

In [45]:
def stack_runs_for_each_model(models,x,y, dtype=dtype):
    """Stack all ensemble members of all models in a single tensor.

       Args:

       Return:
    """

    for idx_m,m in enumerate(models):
        if idx_m == 0:
            x_stacked = x[m]
            y_stacked = y[m]
        else:   

            x_stacked = torch.cat((x_stacked, x[m]), dim=0)
            y_stacked = torch.cat((y_stacked, x[m]), dim=0)

    return x_stacked, y_stacked


In [46]:
x_tmp, y_tmp = stack_runs_for_each_model(training_models,x_rescaled,y_rescaled)

# Construct variational autoencoder 

In [47]:
import torch
import torch.nn as nn
import torch.optim as optim

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.hidden_to_mean = nn.Linear(hidden_dim, latent_dim)
        self.hidden_to_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        _, (h, _) = self.lstm(x)
        h = h[-1]
        mean = self.hidden_to_mean(h)
        logvar = self.hidden_to_logvar(h)
        return mean, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.hidden_to_output = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, seq_len):
        h = self.latent_to_hidden(z).unsqueeze(1).repeat(1, seq_len, 1)
        out, _ = self.lstm(h)
        out = self.hidden_to_output(out)
        return out

class TimeVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim):
        super(TimeVAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, output_dim)

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        recon_x = self.decoder(z, x.size(1))
        return recon_x, mean, logvar

def loss_function(recon_x, x, mean, logvar):
    recon_loss = nn.MSELoss()(recon_x, x)
    # recon_loss = torch.sum((recon_loss - x)**2)  
    kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return recon_loss 

# Example usage
input_dim = len(notnan_idx)  # Number of features in the time series
hidden_dim = 64
latent_dim = 16
output_dim = len(notnan_idx)  # Same as input_dim

# number of epochs
nbEpochs = 300


model = TimeVAE(input_dim, hidden_dim, latent_dim, output_dim)
optimizer = optim.SGD(model.parameters(), lr=1e-3)


model.train()
for epoch in range(nbEpochs):
    optimizer.zero_grad()
    recon_x, mean, logvar = model(x_rescaled[m0][:,:,notnan_idx])
    loss = loss_function(recon_x, y_rescaled[m0][:,:,notnan_idx], mean, logvar)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")


Epoch 1, Loss: 2.128553628921509
Epoch 2, Loss: 2.12736439704895
Epoch 3, Loss: 2.124907970428467
Epoch 4, Loss: 2.1261749267578125
Epoch 5, Loss: 2.1267998218536377
Epoch 6, Loss: 2.1280438899993896
Epoch 7, Loss: 2.126358985900879
Epoch 8, Loss: 2.123724937438965
Epoch 9, Loss: 2.1247546672821045
Epoch 10, Loss: 2.126061201095581
Epoch 11, Loss: 2.1261165142059326
Epoch 12, Loss: 2.128704309463501
Epoch 13, Loss: 2.1261889934539795
Epoch 14, Loss: 2.1283042430877686
Epoch 15, Loss: 2.1265785694122314
Epoch 16, Loss: 2.1263999938964844
Epoch 17, Loss: 2.12776517868042
Epoch 18, Loss: 2.127451181411743
Epoch 19, Loss: 2.126783609390259
Epoch 20, Loss: 2.1262452602386475
Epoch 21, Loss: 2.1248250007629395
Epoch 22, Loss: 2.1260480880737305
Epoch 23, Loss: 2.1263158321380615
Epoch 24, Loss: 2.125706195831299
Epoch 25, Loss: 2.1270270347595215
Epoch 26, Loss: 2.1268861293792725
Epoch 27, Loss: 2.1255388259887695
Epoch 28, Loss: 2.1252901554107666
Epoch 29, Loss: 2.12726092338562
Epoch 30,

In [48]:
y_pred_vae = model(x_rescaled[m0][:,:,notnan_idx])[0]


In [49]:
from algorithms import ridge_regression, ridge_regression_low_rank, low_rank_projection, \
                        prediction, train_robust_weights_model, compute_weights

In [50]:
lambda_tmp = 10.0

# compute the big matrix X and Y
training_models, x_train_merged, y_train_merged, x_test_merged, y_test_merged = rescale_and_merge_training_and_test_sets(m0,x,y,means,vars,dtype=dtype)

# compute ridge regressor
w_ridge = torch.zeros(lon_size*lat_size,lon_size*lat_size,dtype=dtype)
w_ridge[np.ix_(notnan_idx,notnan_idx)] = ridge_regression(x_train_merged[:,notnan_idx], y_train_merged[:,notnan_idx], lambda_=lambda_tmp, dtype=dtype)

x_test_tmp = x[m0]
y_test_tmp = y[m0]

# ridge
y_pred_ridge = prediction(x_test_tmp, w_ridge,notnan_idx, nan_idx)

In [51]:
import torch
import torch.nn as nn
import torch.optim as optim

class Embedder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(Embedder, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        h, _ = self.lstm(x)
        h = self.fc(h)
        return h

class Recovery(nn.Module):
    def __init__(self, hidden_dim, output_dim, num_layers):
        super(Recovery, self).__init__()
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, h):
        h, _ = self.lstm(h)
        x_tilde = self.fc(h)
        return x_tilde

class Generator(nn.Module):
    def __init__(self, z_dim, hidden_dim, num_layers):
        super(Generator, self).__init__()
        self.lstm = nn.LSTM(z_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, z):
        h, _ = self.lstm(z)
        h = self.fc(h)
        return h

class Supervisor(nn.Module):
    def __init__(self, hidden_dim, num_layers):
        super(Supervisor, self).__init__()
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h):
        h, _ = self.lstm(h)
        h = self.fc(h)
        return h

class Discriminator(nn.Module):
    def __init__(self, hidden_dim, num_layers):
        super(Discriminator, self).__init__()
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, h):
        h, _ = self.lstm(h)
        y_hat = self.fc(h)
        return y_hat

class TimeGAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim, num_layers):
        super(TimeGAN, self).__init__()
        self.embedder = Embedder(input_dim, hidden_dim, num_layers)
        self.recovery = Recovery(hidden_dim, input_dim, num_layers)
        self.generator = Generator(z_dim, hidden_dim, num_layers)
        self.supervisor = Supervisor(hidden_dim, num_layers)
        self.discriminator = Discriminator(hidden_dim, num_layers)

    def forward(self, x, z):
        h = self.embedder(x)
        x_tilde = self.recovery(h)
        e_hat = self.generator(z)
        h_hat_supervise = self.supervisor(h)
        y_fake = self.discriminator(e_hat)
        y_real = self.discriminator(h)
        return x_tilde, h, e_hat, h_hat_supervise, y_fake, y_real


In [52]:
# we try with a different loss function
nbEpochs = 500
input_dim = len(notnan_idx)  # Number of features in the time series
hidden_dim = 24
z_dim = 8
num_layers = 3

z = torch.randn(x_tmp.shape[0], 34, z_dim)  # Latent space input

model = TimeGAN(input_dim, hidden_dim, z_dim, num_layers)
optimizer_embedder = optim.Adam(model.parameters(), lr=1e-3)
optimizer_recovery = optim.Adam(model.parameters(), lr=1e-3)
optimizer_generator = optim.Adam(model.parameters(), lr=1e-3)
optimizer_supervisor = optim.Adam(model.parameters(), lr=1e-3)
optimizer_discriminator = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(nbEpochs):
    optimizer_embedder.zero_grad()
    optimizer_recovery.zero_grad()
    optimizer_generator.zero_grad()
    optimizer_supervisor.zero_grad()
    optimizer_discriminator.zero_grad()

    # Forward pass
    x_tilde, h, e_hat, h_hat_supervise, y_fake, y_real = model(x_tmp[:,:,notnan_idx], z)
    
    # Compute losses
    reconstruction_loss = nn.MSELoss()(x_tilde, y_tmp[:,:,notnan_idx])
    supervised_loss = nn.MSELoss()(h_hat_supervise[:, 1:, :], h[:, :-1, :])
    real_loss = nn.BCEWithLogitsLoss()(y_real, torch.ones_like(y_real))
    fake_loss = nn.BCEWithLogitsLoss()(y_fake, torch.zeros_like(y_fake))
    discriminator_loss = real_loss + fake_loss
    generator_loss = nn.BCEWithLogitsLoss()(y_fake, torch.ones_like(y_fake))
    embedding_loss = nn.MSELoss()(e_hat, h)

    # Total losses
    embedder_recovery_loss = reconstruction_loss + embedding_loss
    supervisor_loss = supervised_loss
    generator_total_loss = generator_loss + supervised_loss + embedding_loss
    discriminator_total_loss = discriminator_loss

    # Backward pass and optimization
    total_loss = embedder_recovery_loss + supervisor_loss + generator_total_loss + discriminator_total_loss
    total_loss.backward()

    # Backward pass and optimization
    # embedder_recovery_loss.backward(retain_graph=True)
    optimizer_embedder.step()
    optimizer_recovery.step()

    # supervisor_loss.backward(retain_graph=True)
    optimizer_supervisor.step()

    # generator_total_loss.backward(retain_graph=True)
    optimizer_generator.step()

    # discriminator_total_loss.backward()
    optimizer_discriminator.step()

    # print(f"Epoch {epoch + 1}, Losses: E/R: {embedder_recovery_loss.item()}, S: {supervisor_loss.item()}, G: {generator_total_loss.item()}, D: {discriminator_total_loss.item()}")
    print(f"Epoch {epoch + 1}, Loss: {total_loss.item()}")

Epoch 1, Loss: 3.6527366638183594
Epoch 2, Loss: 3.5717344284057617
Epoch 3, Loss: 3.5051472187042236
Epoch 4, Loss: 3.447333335876465
Epoch 5, Loss: 3.400808095932007
Epoch 6, Loss: 3.370743989944458
Epoch 7, Loss: 3.373318672180176
Epoch 8, Loss: 3.3766720294952393
Epoch 9, Loss: 3.361311435699463
Epoch 10, Loss: 3.3446106910705566
Epoch 11, Loss: 3.332108974456787
Epoch 12, Loss: 3.319951057434082
Epoch 13, Loss: 3.3024516105651855
Epoch 14, Loss: 3.2752745151519775
Epoch 15, Loss: 3.239579677581787
Epoch 16, Loss: 3.200315475463867
Epoch 17, Loss: 3.1603174209594727
Epoch 18, Loss: 3.1197032928466797
Epoch 19, Loss: 3.067256450653076
Epoch 20, Loss: 3.0069432258605957
Epoch 21, Loss: 2.9460573196411133
Epoch 22, Loss: 2.889329433441162
Epoch 23, Loss: 2.8391470909118652
Epoch 24, Loss: 2.7968831062316895
Epoch 25, Loss: 2.7646162509918213
Epoch 26, Loss: 2.735165596008301
Epoch 27, Loss: 2.713688373565674
Epoch 28, Loss: 2.6954362392425537
Epoch 29, Loss: 2.680443525314331
Epoch 30

In [53]:
y_pred_gan = model(x_rescaled[m0][:,:,notnan_idx], z)[0]

In [54]:
# compare the rmse of the three methods
rmse_ridge = torch.sqrt(torch.nanmean((y_test_tmp - y_pred_ridge)**2))
rmse_vae = torch.sqrt(torch.mean((y_test_tmp[:,:,notnan_idx] - y_pred_vae)**2))
rmse_gan = torch.sqrt(torch.mean((y_test_tmp[:,:,notnan_idx] - y_pred_gan)**2))

print(f"RMSE Ridge: {rmse_ridge}")
print(f"RMSE VAE: {rmse_vae}")
print(f"RMSE GAN: {rmse_gan}")

RMSE Ridge: 0.008337395265698433
RMSE VAE: 19.75765037536621
RMSE GAN: 19.878334045410156
