In [1]:
import os
import sys
maindir = os.getcwd()
sys.path.append(maindir+"/src")

In [2]:
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt

from preprocessing import data_processing, compute_anomalies_and_scalers, \
                            compute_forced_response, \
                            numpy_to_torch, rescale_and_merge_training_and_test_sets, \
                            rescale_training_and_test_sets


from plot_tools import plot_gt_vs_pred, animation_gt_vs_pred
from leave_one_out import leave_one_out_single, leave_one_out_procedure
from cross_validation import cross_validation_procedure

In [55]:
############### Load climate model raw data for SST
with open('data/ssp585_time_series.pkl', 'rb') as f:
    data = pickle.load(f)

###################### Load longitude and latitude 
with open('data/lon.npy', 'rb') as f:
    lon = np.load(f)

with open('data/lat.npy', 'rb') as f:
    lat = np.load(f)

# define grid (+ croping for latitude > 60)
lat_grid, lon_grid = np.meshgrid(lat[lat<=60], lon, indexing='ij')

lat_size = lat_grid.shape[0]
lon_size = lon_grid.shape[1]

In [56]:
# define pytorch precision
dtype = torch.float32

data_processed, notnan_idx, nan_idx = data_processing(data, lon, lat,max_models=100)
x, means, vars = compute_anomalies_and_scalers(data_processed, lon_size, lat_size, nan_idx, time_period=34)
y = compute_forced_response(data_processed, lon_size, lat_size, nan_idx, time_period=34)

x,y, means, vars = numpy_to_torch(x,y,means,vars, dtype=dtype)

############### REMOVE model 'GISS-E2-2-G' from the dataset ################
x.pop('GISS-E2-2-G')
y.pop('GISS-E2-2-G')
means.pop('GISS-E2-2-G')
vars.pop('GISS-E2-2-G')

  means[m] = np.nanmean(data_reshaped[m],axis=(0,1))
  vars[m] = np.nanvar(data_reshaped[m],axis=(0,1))
  mean_spatial_ensemble = np.nanmean(y_tmp,axis=0)


tensor([   nan,    nan,    nan,  ..., 0.1041, 0.1140,    nan])

In [57]:
# m0 = 'ICON-ESM-LR'
# m0 = 'GISS-E2-2-G'
m0 = 'GISS-E2-2-H'

training_models, x_train, y_train, x_test, y_test = rescale_and_merge_training_and_test_sets(m0,x,y,means,vars,dtype=dtype)
training_models, x_rescaled, y_rescaled = rescale_training_and_test_sets(m0,x,y,means,vars,dtype=dtype)

In [58]:
def stack_runs_for_each_model(models,x,y, dtype=dtype):
    """Stack all ensemble members of all models in a single tensor.

       Args:

       Return:
    """

    for idx_m,m in enumerate(models):
        if idx_m == 0:
            x_stacked = x[m]
            y_stacked = y[m]
        else:   

            x_stacked = torch.cat((x_stacked, x[m]), dim=0)
            y_stacked = torch.cat((y_stacked, x[m]), dim=0)

    return x_stacked, y_stacked


In [59]:
x_tmp, y_tmp = stack_runs_for_each_model(training_models,x_rescaled,y_rescaled)

# Construct variational autoencoder 

In [60]:
import torch
import torch.nn as nn
import torch.optim as optim

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.hidden_to_mean = nn.Linear(hidden_dim, latent_dim)
        self.hidden_to_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        _, (h, _) = self.lstm(x)
        h = h[-1]
        mean = self.hidden_to_mean(h)
        logvar = self.hidden_to_logvar(h)
        return mean, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.hidden_to_output = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, seq_len):
        h = self.latent_to_hidden(z).unsqueeze(1).repeat(1, seq_len, 1)
        out, _ = self.lstm(h)
        out = self.hidden_to_output(out)
        return out

class TimeVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim):
        super(TimeVAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, output_dim)

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        recon_x = self.decoder(z, x.size(1))
        return recon_x, mean, logvar

def loss_function(recon_x, x, mean, logvar):
    recon_loss = nn.MSELoss()(recon_x, x)
    # recon_loss = torch.sum((recon_loss - x)**2)  
    kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return recon_loss 

# Example usage
input_dim = len(notnan_idx)  # Number of features in the time series
hidden_dim = 64
latent_dim = 32
output_dim = len(notnan_idx)  # Same as input_dim

# number of epochs
nbEpochs = 300


model_vae = TimeVAE(input_dim, hidden_dim, latent_dim, output_dim)
optimizer = optim.Adam(model_vae.parameters(), lr=1e-3)


model_vae.train()
for epoch in range(nbEpochs):
    optimizer.zero_grad()
    recon_x, mean, logvar = model_vae(x_tmp[:,:,notnan_idx])
    loss = loss_function(recon_x, y_tmp[:,:,notnan_idx], mean, logvar)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")


Epoch 1, Loss: 1.0082672834396362
Epoch 2, Loss: 1.0065789222717285
Epoch 3, Loss: 1.0051600933074951
Epoch 4, Loss: 1.0038495063781738
Epoch 5, Loss: 1.0030747652053833
Epoch 6, Loss: 1.0022313594818115
Epoch 7, Loss: 1.0013911724090576
Epoch 8, Loss: 1.0006290674209595
Epoch 9, Loss: 0.9999138712882996
Epoch 10, Loss: 0.9993199706077576
Epoch 11, Loss: 0.9987760782241821
Epoch 12, Loss: 0.9982327818870544
Epoch 13, Loss: 0.997759997844696
Epoch 14, Loss: 0.9973581433296204
Epoch 15, Loss: 0.9969742298126221
Epoch 16, Loss: 0.9966176152229309
Epoch 17, Loss: 0.9963246583938599
Epoch 18, Loss: 0.9959900379180908
Epoch 19, Loss: 0.9958310127258301
Epoch 20, Loss: 0.9955629706382751
Epoch 21, Loss: 0.9953388571739197
Epoch 22, Loss: 0.995152473449707
Epoch 23, Loss: 0.9949284195899963
Epoch 24, Loss: 0.9946921467781067
Epoch 25, Loss: 0.9943747520446777
Epoch 26, Loss: 0.9940397143363953
Epoch 27, Loss: 0.9937264323234558
Epoch 28, Loss: 0.9932270646095276
Epoch 29, Loss: 0.9927393198013

In [61]:
from algorithms import ridge_regression, ridge_regression_low_rank, train_robust_weights, train_robust_weights_trace_norm,\
                        prediction, compute_weights

In [62]:
lambda_tmp = 100.0

# compute the big matrix X and Y
training_models, x_train_merged, y_train_merged, x_test_merged, y_test_merged = rescale_and_merge_training_and_test_sets(m0,x,y,means,vars,dtype=dtype)

# compute ridge regressor
w_ridge = torch.zeros(lon_size*lat_size,lon_size*lat_size,dtype=dtype)
w_ridge[np.ix_(notnan_idx,notnan_idx)] = ridge_regression(x_train_merged[:,notnan_idx], y_train_merged[:,notnan_idx], lambda_=lambda_tmp, dtype=dtype)

x_test_tmp = x_rescaled[m0]
y_test_tmp = y_rescaled[m0]

# ridge
y_pred_ridge = prediction(x_test_tmp, w_ridge,notnan_idx, nan_idx)

In [63]:
import torch
import torch.nn as nn
import torch.optim as optim

class Embedder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(Embedder, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        h, _ = self.lstm(x)
        h = self.fc(h)
        return h

class Recovery(nn.Module):
    def __init__(self, hidden_dim, output_dim, num_layers):
        super(Recovery, self).__init__()
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, h):
        h, _ = self.lstm(h)
        x_tilde = self.fc(h)
        return x_tilde

class Generator(nn.Module):
    def __init__(self, z_dim, hidden_dim, num_layers):
        super(Generator, self).__init__()
        self.lstm = nn.LSTM(z_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, z):
        h, _ = self.lstm(z)
        h = self.fc(h)
        return h

class Supervisor(nn.Module):
    def __init__(self, hidden_dim, num_layers):
        super(Supervisor, self).__init__()
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h):
        h, _ = self.lstm(h)
        h = self.fc(h)
        return h

class Discriminator(nn.Module):
    def __init__(self, hidden_dim, num_layers):
        super(Discriminator, self).__init__()
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, h):
        h, _ = self.lstm(h)
        y_hat = self.fc(h)
        return y_hat

class TimeGAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim, num_layers):
        super(TimeGAN, self).__init__()
        self.embedder = Embedder(input_dim, hidden_dim, num_layers)
        self.recovery = Recovery(hidden_dim, input_dim, num_layers)
        self.generator = Generator(z_dim, hidden_dim, num_layers)
        self.supervisor = Supervisor(hidden_dim, num_layers)
        self.discriminator = Discriminator(hidden_dim, num_layers)

    def forward(self, x, z):
        h = self.embedder(x)
        x_tilde = self.recovery(h)
        e_hat = self.generator(z)
        h_hat_supervise = self.supervisor(h)
        y_fake = self.discriminator(e_hat)
        y_real = self.discriminator(h)
        return x_tilde, h, e_hat, h_hat_supervise, y_fake, y_real


In [64]:
# we try with a different loss function
nbEpochs = 100
input_dim = len(notnan_idx)  # Number of features in the time series
hidden_dim = 8
z_dim = 8
num_layers = 5

z = torch.randn(x_tmp.shape[0], 34, z_dim)  # Latent space input

model = TimeGAN(input_dim, hidden_dim, z_dim, num_layers)
optimizer_embedder = optim.Adam(model.parameters(), lr=1e-3)
optimizer_recovery = optim.Adam(model.parameters(), lr=1e-3)
optimizer_generator = optim.Adam(model.parameters(), lr=1e-3)
optimizer_supervisor = optim.Adam(model.parameters(), lr=1e-3)
optimizer_discriminator = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(nbEpochs):
    optimizer_embedder.zero_grad()
    optimizer_recovery.zero_grad()
    optimizer_generator.zero_grad()
    optimizer_supervisor.zero_grad()
    optimizer_discriminator.zero_grad()

    # Forward pass
    x_tilde, h, e_hat, h_hat_supervise, y_fake, y_real = model(x_tmp[:,:,notnan_idx], z)
    
    # Compute losses
    reconstruction_loss = nn.MSELoss()(x_tilde, y_tmp[:,:,notnan_idx])
    supervised_loss = nn.MSELoss()(h_hat_supervise[:, 1:, :], h[:, :-1, :])
    real_loss = nn.BCEWithLogitsLoss()(y_real, torch.ones_like(y_real))
    fake_loss = nn.BCEWithLogitsLoss()(y_fake, torch.zeros_like(y_fake))
    discriminator_loss = real_loss + fake_loss
    generator_loss = nn.BCEWithLogitsLoss()(y_fake, torch.ones_like(y_fake))
    embedding_loss = nn.MSELoss()(e_hat, h)

    # Total losses
    embedder_recovery_loss = reconstruction_loss + embedding_loss
    supervisor_loss = supervised_loss
    generator_total_loss = generator_loss + supervised_loss + embedding_loss
    discriminator_total_loss = discriminator_loss

    # Backward pass and optimization
    total_loss = embedder_recovery_loss + supervisor_loss + generator_total_loss + discriminator_total_loss
    total_loss.backward()

    # Backward pass and optimization
    # embedder_recovery_loss.backward(retain_graph=True)
    optimizer_embedder.step()
    optimizer_recovery.step()

    # supervisor_loss.backward(retain_graph=True)
    optimizer_supervisor.step()

    # generator_total_loss.backward(retain_graph=True)
    optimizer_generator.step()

    # discriminator_total_loss.backward()
    optimizer_discriminator.step()

    # print(f"Epoch {epoch + 1}, Losses: E/R: {embedder_recovery_loss.item()}, S: {supervisor_loss.item()}, G: {generator_total_loss.item()}, D: {discriminator_total_loss.item()}")
    print(f"Epoch {epoch + 1}, Loss: {total_loss.item()}")

Epoch 1, Loss: 3.646031379699707
Epoch 2, Loss: 3.5569403171539307
Epoch 3, Loss: 3.4789774417877197
Epoch 4, Loss: 3.4108266830444336
Epoch 5, Loss: 3.351388931274414
Epoch 6, Loss: 3.2995777130126953
Epoch 7, Loss: 3.254565954208374
Epoch 8, Loss: 3.21575927734375
Epoch 9, Loss: 3.1826531887054443
Epoch 10, Loss: 3.1543169021606445
Epoch 11, Loss: 3.129316806793213
Epoch 12, Loss: 3.1063408851623535
Epoch 13, Loss: 3.0846734046936035
Epoch 14, Loss: 3.0641918182373047
Epoch 15, Loss: 3.0450797080993652
Epoch 16, Loss: 3.0274851322174072
Epoch 17, Loss: 3.011363983154297
Epoch 18, Loss: 2.9965219497680664
Epoch 19, Loss: 2.98275089263916
Epoch 20, Loss: 2.9699482917785645
Epoch 21, Loss: 2.9581432342529297
Epoch 22, Loss: 2.9474759101867676
Epoch 23, Loss: 2.938159465789795
Epoch 24, Loss: 2.9304099082946777
Epoch 25, Loss: 2.92441463470459
Epoch 26, Loss: 2.920304536819458
Epoch 27, Loss: 2.9180703163146973
Epoch 28, Loss: 2.9174654483795166
Epoch 29, Loss: 2.9179906845092773
Epoch 3

### TimeGAN and TimeVAE predictions

In [65]:
y_pred_gan = torch.zeros(x_rescaled[m0].shape, dtype=dtype)
y_pred_gan[:,:,notnan_idx] = model(x_rescaled[m0][:,:,notnan_idx], z)[0]
y_pred_gan[:,:,nan_idx] = float('nan')


y_pred_vae = torch.zeros(y_test_tmp.shape, dtype=dtype)
y_pred_vae[:,:,notnan_idx] = model_vae(x_rescaled[m0][:,:,notnan_idx])[0]
y_pred_vae[:,:,nan_idx] = float('nan')

In [67]:
# compare the rmse of the three methods
rmse_ridge = torch.nanmean((y_test_tmp - y_pred_ridge)**2)
rmse_vae = torch.nanmean((y_test_tmp - y_pred_vae)**2)
rmse_gan = torch.nanmean((y_test_tmp - y_pred_gan)**2)

print(f"RMSE Ridge: {rmse_ridge}")
print(f"RMSE VAE: {rmse_vae}")
print(f"RMSE GAN: {rmse_gan}")

RMSE Ridge: 3.877584934234619
RMSE VAE: 0.38371947407722473
RMSE GAN: 0.1861359030008316


In [None]:
from matplotlib import animation

time_period=33
fmax = 2.0
fmin = -1.0


plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150  
plt.ioff()


plt.close('all')
fig0 = plt.figure(figsize=(24,16))

ax0 = fig0.add_subplot(2, 3, 1)        
ax0.set_title(r'Groundtruth', size=7,pad=3.0)
ax0.set_xlabel(r'x', size=7)
ax0.set_ylabel(r'y', size=7)

ax1 = fig0.add_subplot(2, 3, 2)        
ax1.set_title(r'Ridge regression', size=7,pad=3.0)
ax1.set_xlabel(r'x', size=7)
ax1.set_ylabel(r'y', size=7)

ax2 = fig0.add_subplot(2, 3, 3)        
ax2.set_title(r'VAE predictions', size=7,pad=3.0)
ax2.set_xlabel(r'x', size=7)
ax2.set_ylabel(r'y', size=7)

ax3 = fig0.add_subplot(2, 3, 5)        
ax3.set_title(r'GAN predictions', size=7,pad=3.0)
ax3.set_xlabel(r'x', size=7)
ax3.set_ylabel(r'y', size=7)

# ax4 = fig0.add_subplot(3, 3, 6)        
# ax4.set_title(r'Low rank Robust regression', size=7,pad=3.0)
# ax4.set_xlabel(r'x', size=7)
# ax4.set_ylabel(r'y', size=7)

# ax5 = fig0.add_subplot(3, 3, 8)
# ax5.set_title(r'Trace norm', size=7,pad=3.0)
# ax5.set_xlabel(r'x', size=7)
# ax5.set_ylabel(r'y', size=7)

# ax6 = fig0.add_subplot(3, 3, 9)
# ax6.set_title(r'Robust trace norm', size=7,pad=3.0)
# ax6.set_xlabel(r'x', size=7)
# ax6.set_ylabel(r'y', size=7)

# get first run of the test set
idx_run = 0
y_to_plot_target = y_test_tmp[idx_run,0,:].detach().numpy().reshape(lat_size,lon_size)
y_to_plot_pred_ridge = y_pred_ridge[idx_run,0,:].detach().numpy().reshape(lat_size,lon_size) # ridge
y_to_plot_pred_vae = y_pred_vae[idx_run,0,:].detach().numpy().reshape(lat_size,lon_size) # ridge low rank
y_to_plot_pred_gan = y_pred_gan[idx_run,0,:].detach().numpy().reshape(lat_size,lon_size) # robust 
# y_to_plot_pred_robust_lr = y_pred_robust_lr[idx_run,0,:].detach().numpy().reshape(lat_size,lon_size) # robust low rank
# y_to_plot_pred_trace = y_pred_trace[idx_run,0,:].detach().numpy().reshape(lat_size,lon_size) # trace norm
# y_to_plot_pred_robust_trace = y_pred_robust_trace[idx_run,0,:].detach().numpy().reshape(lat_size,lon_size) # robust trace norm

im0 = ax0.pcolormesh(lon_grid,lat_grid,y_to_plot_target,vmin=fmin,vmax=fmax)
im1 = ax1.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_gan,vmin=fmin,vmax=fmax)
im2 = ax2.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_vae,vmin=fmin,vmax=fmax)
# im3 = ax3.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_gan,vmin=fmin,vmax=fmax)
# im4 = ax4.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_robust_lr,vmin=-1.0,vmax=5.0)
# im5 = ax5.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_trace,vmin=fmin,vmax=fmax)
# im6 = ax6.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_robust_trace,vmin=-1.0,vmax=2.0)

def animate_maps(i):

    y_to_plot_target = y_test_tmp[idx_run,i,:].detach().numpy().reshape(lat_size,lon_size)
    y_to_plot_pred_ridge = y_pred_ridge[idx_run,i,:].detach().numpy().reshape(lat_size,lon_size) # ridge
    y_to_plot_pred_vae = y_pred_vae[idx_run,i,:].detach().numpy().reshape(lat_size,lon_size) # ridge low rank
    y_to_plot_pred_gan = y_pred_gan[idx_run,i,:].detach().numpy().reshape(lat_size,lon_size) # robust 
    # y_to_plot_pred_robust_lr = y_pred_robust_lr[idx_run,i,:].detach().numpy().reshape(lat_size,lon_size) # robust low rank

    im0 = ax0.pcolormesh(lon_grid,lat_grid,y_to_plot_target,vmin=fmin,vmax=fmax)
    im1 = ax1.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_gan,vmin=fmin,vmax=fmax)
    im2 = ax2.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_vae,vmin=fmin,vmax=fmax)
    # im3 = ax3.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_gan,vmin=fmin,vmax=fmax)
    # im4 = ax4.pcolormesh(lon_grid,lat_grid,y_to_plot_pred_robust_lr,vmin=-1.0,vmax=5.0)
    
plt.colorbar(im0, ax=ax0, shrink=0.3)
plt.colorbar(im1, ax=ax1, shrink=0.3)
plt.colorbar(im2, ax=ax2, shrink=0.3)
# plt.colorbar(im3, ax=ax3, shrink=0.3)
# plt.colorbar(im4, ax=ax4, shrink=0.3)
animation.FuncAnimation(fig0, animate_maps, frames=time_period)

In [70]:
y_pred_vae

tensor([[[    nan,     nan,     nan,  ..., -0.1351, -0.0970,     nan],
         [    nan,     nan,     nan,  ..., -0.1571, -0.1298,     nan],
         [    nan,     nan,     nan,  ..., -0.1502, -0.1243,     nan],
         ...,
         [    nan,     nan,     nan,  ...,  0.0294,  0.0677,     nan],
         [    nan,     nan,     nan,  ...,  0.0311,  0.0701,     nan],
         [    nan,     nan,     nan,  ...,  0.0327,  0.0722,     nan]],

        [[    nan,     nan,     nan,  ..., -0.1689, -0.1103,     nan],
         [    nan,     nan,     nan,  ..., -0.1959, -0.1513,     nan],
         [    nan,     nan,     nan,  ..., -0.1923, -0.1591,     nan],
         ...,
         [    nan,     nan,     nan,  ..., -0.0483, -0.0060,     nan],
         [    nan,     nan,     nan,  ..., -0.0469, -0.0046,     nan],
         [    nan,     nan,     nan,  ..., -0.0456, -0.0032,     nan]],

        [[    nan,     nan,     nan,  ..., -0.6123, -0.6714,     nan],
         [    nan,     nan,     nan,  ..., -0

In [69]:
y_pred_gan

tensor([[[    nan,     nan,     nan,  ..., -0.1988, -0.2518,     nan],
         [    nan,     nan,     nan,  ..., -0.3082, -0.3662,     nan],
         [    nan,     nan,     nan,  ..., -0.3486, -0.4077,     nan],
         ...,
         [    nan,     nan,     nan,  ...,  0.5273,  0.6249,     nan],
         [    nan,     nan,     nan,  ...,  0.5410,  0.6416,     nan],
         [    nan,     nan,     nan,  ...,  0.5518,  0.6547,     nan]],

        [[    nan,     nan,     nan,  ..., -0.1988, -0.2518,     nan],
         [    nan,     nan,     nan,  ..., -0.3082, -0.3662,     nan],
         [    nan,     nan,     nan,  ..., -0.3486, -0.4077,     nan],
         ...,
         [    nan,     nan,     nan,  ...,  0.5273,  0.6249,     nan],
         [    nan,     nan,     nan,  ...,  0.5410,  0.6416,     nan],
         [    nan,     nan,     nan,  ...,  0.5518,  0.6547,     nan]],

        [[    nan,     nan,     nan,  ..., -0.1988, -0.2518,     nan],
         [    nan,     nan,     nan,  ..., -0