# Variational Auto Encoder
This notebook demonstrates the training of a Variational Autoencoder (VAE) to learn transformations from input climate data to corresponding forced responses.

In [1]:
# Import necessary libraries
import pandas as pd
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os, sys
import random
import warnings


from torch.utils.data import DataLoader

# Add utility paths
sys.path.append(os.path.join(os.getcwd(), 'utils'))

# Import utility functions
from utils.data_loading import *
from utils.data_processing import *
from utils.vae import *
from utils.animation import *
from utils.metrics import *
from utils.pipeline import *

# Enable autoreload
%reload_ext autoreload
%autoreload 2

# Suppress warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Define data path
current_dir = os.getcwd()
data_path = os.path.join(current_dir, 'data')
print(f"Data path: {data_path}")

Data path: /Users/lharriso/Documents/GitHub/gm4cs-l/data


In [2]:
# Use MPS / Cuda or CPU if none of the options are available
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
random.seed(42)

Using device: mps


In [3]:
# Load the data
filename = os.path.join(data_path, 'ssp585_time_series.pkl')
data, nan_mask = preprocess_data(data_path, filename)

Loading data from /Users/lharriso/Documents/GitHub/gm4cs-l/data/ssp585_time_series.pkl
Data loaded successfully.
Filtering data...
Data loaded successfully.
Filtering data...


100%|██████████| 72/72 [00:00<00:00, 54967.22it/s]
100%|██████████| 72/72 [00:00<00:00, 54967.22it/s]


Data filtered. Kept 34 models
Creating NaN mask...


100%|██████████| 34/34 [00:02<00:00, 12.84it/s]
100%|██████████| 34/34 [00:02<00:00, 12.84it/s]


NaN mask created.
Masking out NaN values...


100%|██████████| 34/34 [00:02<00:00, 15.97it/s]
100%|██████████| 34/34 [00:02<00:00, 15.97it/s]


NaN values masked out.
Reshaping data...


100%|██████████| 34/34 [00:04<00:00,  7.37it/s]
100%|██████████| 34/34 [00:04<00:00,  7.37it/s]


Data reshaped.
Adding the forced response to the data...


100%|██████████| 34/34 [00:04<00:00,  8.04it/s]
100%|██████████| 34/34 [00:04<00:00,  8.04it/s]


Forced response added.
Removing NaN values from the grid...


100%|██████████| 34/34 [00:03<00:00, 11.03it/s]
100%|██████████| 34/34 [00:03<00:00, 11.03it/s]


NaN values removed.


In [4]:
# Randomly select and keep the data corresponding to n models
n = 5
model_keys = random.sample(data.keys(), n)
data = {key: value for key,value in data.items() if key in model_keys}

since Python 3.9 and will be removed in a subsequent version.
  model_keys = random.sample(data.keys(), n)


In [5]:
# Select one of the models randomly for testing and the rest for training according to the leave-one-out strategy
test_model = random.choice(list(data.keys()))
train_models = [model for model in data.keys() if model != test_model]

# Create the training and testing datasets
train_data = {model: data[model] for model in train_models}
test_data = {test_model: data[test_model]}

print(f"Training models: {train_models}")
print(f"Testing model: {test_model}")

Training models: ['EC-Earth3', 'E3SM-2-0', 'GISS-E2-1-G', 'ACCESS-ESM1-5']
Testing model: GISS-E2-2-G


In [9]:
# Create dataset
train_dataset = ClimateDataset(train_data)
test_dataset = ClimateDataset(test_data)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Print dataset sizes
print(f'Training dataset size: {len(train_dataset)}')
print(f'Testing dataset size: {len(test_dataset)}')

Creating datasets...


Processing models:   0%|          | 0/4 [00:00<?, ?it/s]

Processing models: 100%|██████████| 4/4 [00:00<00:00, 46474.28it/s]

Creating datasets...



Processing models: 100%|██████████| 1/1 [00:00<00:00, 18157.16it/s]
Processing models: 100%|██████████| 1/1 [00:00<00:00, 18157.16it/s]


Training dataset size: 124
Testing dataset size: 11


In [10]:
# Initialize the VAE model
input_dim = train_dataset.inputs.shape[1] * train_dataset.inputs.shape[2]  # Flattened input dimensions
latent_dim = 100
hidden_dim = 50
mu_var_dim = 10
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
vae_model = VAE(input_dim=input_dim, hidden_dim = hidden_dim, latent_dim=latent_dim, device=device, mu_var_dim=mu_var_dim).to(device)

# Apply weight initialization
vae_model.apply(initialize_weights)

# Define optimizer
optimizer = torch.optim.Adam(vae_model.parameters(), lr=1e-3)

# Train the VAE
train_vae(vae_model, train_loader, optimizer, epochs=10, device=device)

Using device: mps
Input dimensions: torch.Size([16, 165, 6523])
Output dimensions: torch.Size([16, 165, 6523])
Input dimensions: torch.Size([16, 165, 6523])
Output dimensions: torch.Size([16, 165, 6523])


  0%|          | 0/10 [00:00<?, ?it/s]

NaN detected in reconstructed output at Epoch 1, Batch 3
NaN detected in mean at Epoch 1, Batch 3
NaN detected in logvar at Epoch 1, Batch 3
NaN detected in loss at Epoch 1, Batch 3
NaN detected in reconstructed output at Epoch 1, Batch 4
NaN detected in mean at Epoch 1, Batch 4
NaN detected in logvar at Epoch 1, Batch 4
NaN detected in loss at Epoch 1, Batch 4
NaN detected in reconstructed output at Epoch 1, Batch 5
NaN detected in mean at Epoch 1, Batch 5
NaN detected in logvar at Epoch 1, Batch 5
NaN detected in loss at Epoch 1, Batch 5
NaN detected in reconstructed output at Epoch 1, Batch 6
NaN detected in mean at Epoch 1, Batch 6
NaN detected in logvar at Epoch 1, Batch 6
NaN detected in loss at Epoch 1, Batch 6
NaN detected in reconstructed output at Epoch 1, Batch 7
NaN detected in mean at Epoch 1, Batch 7
NaN detected in logvar at Epoch 1, Batch 7
NaN detected in loss at Epoch 1, Batch 7
NaN detected in reconstructed output at Epoch 1, Batch 6
NaN detected in mean at Epoch 1, 

 10%|█         | 1/10 [00:01<00:12,  1.38s/it]

Epoch 1, Average Loss: nan
NaN detected in reconstructed output at Epoch 2, Batch 1
NaN detected in mean at Epoch 2, Batch 1
NaN detected in logvar at Epoch 2, Batch 1
NaN detected in loss at Epoch 2, Batch 1
NaN detected in reconstructed output at Epoch 2, Batch 2
NaN detected in mean at Epoch 2, Batch 2
NaN detected in logvar at Epoch 2, Batch 2
NaN detected in loss at Epoch 2, Batch 2
NaN detected in reconstructed output at Epoch 2, Batch 3
NaN detected in mean at Epoch 2, Batch 3
NaN detected in logvar at Epoch 2, Batch 3
NaN detected in loss at Epoch 2, Batch 3
NaN detected in reconstructed output at Epoch 2, Batch 4
NaN detected in mean at Epoch 2, Batch 4
NaN detected in logvar at Epoch 2, Batch 4
NaN detected in loss at Epoch 2, Batch 4
NaN detected in reconstructed output at Epoch 2, Batch 5
NaN detected in mean at Epoch 2, Batch 5
NaN detected in logvar at Epoch 2, Batch 5
NaN detected in loss at Epoch 2, Batch 5
NaN detected in reconstructed output at Epoch 2, Batch 3
NaN de

 20%|██        | 2/10 [00:01<00:07,  1.10it/s]

NaN detected in reconstructed output at Epoch 2, Batch 6
NaN detected in mean at Epoch 2, Batch 6
NaN detected in logvar at Epoch 2, Batch 6
NaN detected in loss at Epoch 2, Batch 6
NaN detected in reconstructed output at Epoch 2, Batch 7
NaN detected in mean at Epoch 2, Batch 7
NaN detected in logvar at Epoch 2, Batch 7
NaN detected in loss at Epoch 2, Batch 7
NaN detected in reconstructed output at Epoch 2, Batch 8
NaN detected in mean at Epoch 2, Batch 8
NaN detected in logvar at Epoch 2, Batch 8
NaN detected in loss at Epoch 2, Batch 8
Epoch 2, Average Loss: nan
NaN detected in reconstructed output at Epoch 3, Batch 1
NaN detected in mean at Epoch 3, Batch 1
NaN detected in logvar at Epoch 3, Batch 1
NaN detected in loss at Epoch 3, Batch 1
NaN detected in reconstructed output at Epoch 3, Batch 2
NaN detected in mean at Epoch 3, Batch 2
NaN detected in logvar at Epoch 3, Batch 2
NaN detected in loss at Epoch 3, Batch 2
NaN detected in reconstructed output at Epoch 3, Batch 3
NaN de

 30%|███       | 3/10 [00:02<00:05,  1.31it/s]

NaN detected in reconstructed output at Epoch 3, Batch 7
NaN detected in mean at Epoch 3, Batch 7
NaN detected in logvar at Epoch 3, Batch 7
NaN detected in loss at Epoch 3, Batch 7
NaN detected in reconstructed output at Epoch 3, Batch 8
NaN detected in mean at Epoch 3, Batch 8
NaN detected in logvar at Epoch 3, Batch 8
NaN detected in loss at Epoch 3, Batch 8
Epoch 3, Average Loss: nan
NaN detected in reconstructed output at Epoch 4, Batch 1
NaN detected in mean at Epoch 4, Batch 1
NaN detected in logvar at Epoch 4, Batch 1
NaN detected in loss at Epoch 4, Batch 1
NaN detected in reconstructed output at Epoch 4, Batch 2
NaN detected in mean at Epoch 4, Batch 2
NaN detected in logvar at Epoch 4, Batch 2
NaN detected in loss at Epoch 4, Batch 2
NaN detected in reconstructed output at Epoch 4, Batch 3
NaN detected in mean at Epoch 4, Batch 3
NaN detected in logvar at Epoch 4, Batch 3
NaN detected in loss at Epoch 4, Batch 3
NaN detected in reconstructed output at Epoch 4, Batch 4
NaN de

 40%|████      | 4/10 [00:03<00:04,  1.43it/s]

NaN detected in reconstructed output at Epoch 4, Batch 8
NaN detected in mean at Epoch 4, Batch 8
NaN detected in logvar at Epoch 4, Batch 8
NaN detected in loss at Epoch 4, Batch 8
Epoch 4, Average Loss: nan
NaN detected in reconstructed output at Epoch 5, Batch 1
NaN detected in mean at Epoch 5, Batch 1
NaN detected in logvar at Epoch 5, Batch 1
NaN detected in loss at Epoch 5, Batch 1
NaN detected in reconstructed output at Epoch 5, Batch 2
NaN detected in mean at Epoch 5, Batch 2
NaN detected in logvar at Epoch 5, Batch 2
NaN detected in loss at Epoch 5, Batch 2
NaN detected in reconstructed output at Epoch 5, Batch 3
NaN detected in mean at Epoch 5, Batch 3
NaN detected in logvar at Epoch 5, Batch 3
NaN detected in loss at Epoch 5, Batch 3
NaN detected in reconstructed output at Epoch 5, Batch 4
NaN detected in mean at Epoch 5, Batch 4
NaN detected in logvar at Epoch 5, Batch 4
NaN detected in loss at Epoch 5, Batch 4
NaN detected in reconstructed output at Epoch 5, Batch 5
NaN de

 50%|█████     | 5/10 [00:03<00:03,  1.50it/s]

NaN detected in reconstructed output at Epoch 5, Batch 6
NaN detected in mean at Epoch 5, Batch 6
NaN detected in logvar at Epoch 5, Batch 6
NaN detected in loss at Epoch 5, Batch 6
NaN detected in reconstructed output at Epoch 5, Batch 7
NaN detected in mean at Epoch 5, Batch 7
NaN detected in logvar at Epoch 5, Batch 7
NaN detected in loss at Epoch 5, Batch 7
NaN detected in reconstructed output at Epoch 5, Batch 8
NaN detected in mean at Epoch 5, Batch 8
NaN detected in logvar at Epoch 5, Batch 8
NaN detected in loss at Epoch 5, Batch 8
Epoch 5, Average Loss: nan
NaN detected in reconstructed output at Epoch 6, Batch 1
NaN detected in mean at Epoch 6, Batch 1
NaN detected in logvar at Epoch 6, Batch 1
NaN detected in loss at Epoch 6, Batch 1
NaN detected in reconstructed output at Epoch 6, Batch 2
NaN detected in mean at Epoch 6, Batch 2
NaN detected in logvar at Epoch 6, Batch 2
NaN detected in loss at Epoch 6, Batch 2
NaN detected in reconstructed output at Epoch 6, Batch 3
NaN de

 60%|██████    | 6/10 [00:04<00:02,  1.56it/s]

NaN detected in reconstructed output at Epoch 6, Batch 7
NaN detected in mean at Epoch 6, Batch 7
NaN detected in logvar at Epoch 6, Batch 7
NaN detected in loss at Epoch 6, Batch 7
NaN detected in reconstructed output at Epoch 6, Batch 8
NaN detected in mean at Epoch 6, Batch 8
NaN detected in logvar at Epoch 6, Batch 8
NaN detected in loss at Epoch 6, Batch 8
Epoch 6, Average Loss: nan
NaN detected in reconstructed output at Epoch 7, Batch 1
NaN detected in mean at Epoch 7, Batch 1
NaN detected in logvar at Epoch 7, Batch 1
NaN detected in loss at Epoch 7, Batch 1
NaN detected in reconstructed output at Epoch 7, Batch 2
NaN detected in mean at Epoch 7, Batch 2
NaN detected in logvar at Epoch 7, Batch 2
NaN detected in loss at Epoch 7, Batch 2
NaN detected in reconstructed output at Epoch 7, Batch 3
NaN detected in mean at Epoch 7, Batch 3
NaN detected in logvar at Epoch 7, Batch 3
NaN detected in loss at Epoch 7, Batch 3
NaN detected in reconstructed output at Epoch 7, Batch 4
NaN de

 70%|███████   | 7/10 [00:04<00:01,  1.60it/s]

NaN detected in reconstructed output at Epoch 7, Batch 8
NaN detected in mean at Epoch 7, Batch 8
NaN detected in logvar at Epoch 7, Batch 8
NaN detected in loss at Epoch 7, Batch 8
Epoch 7, Average Loss: nan
NaN detected in reconstructed output at Epoch 8, Batch 1
NaN detected in mean at Epoch 8, Batch 1
NaN detected in logvar at Epoch 8, Batch 1
NaN detected in loss at Epoch 8, Batch 1
NaN detected in reconstructed output at Epoch 8, Batch 2
NaN detected in mean at Epoch 8, Batch 2
NaN detected in logvar at Epoch 8, Batch 2
NaN detected in loss at Epoch 8, Batch 2
NaN detected in reconstructed output at Epoch 8, Batch 3
NaN detected in mean at Epoch 8, Batch 3
NaN detected in logvar at Epoch 8, Batch 3
NaN detected in loss at Epoch 8, Batch 3
NaN detected in reconstructed output at Epoch 8, Batch 4
NaN detected in mean at Epoch 8, Batch 4
NaN detected in logvar at Epoch 8, Batch 4
NaN detected in loss at Epoch 8, Batch 4
NaN detected in reconstructed output at Epoch 8, Batch 5
NaN de

 80%|████████  | 8/10 [00:05<00:01,  1.63it/s]

NaN detected in reconstructed output at Epoch 8, Batch 6
NaN detected in mean at Epoch 8, Batch 6
NaN detected in logvar at Epoch 8, Batch 6
NaN detected in loss at Epoch 8, Batch 6
NaN detected in reconstructed output at Epoch 8, Batch 7
NaN detected in mean at Epoch 8, Batch 7
NaN detected in logvar at Epoch 8, Batch 7
NaN detected in loss at Epoch 8, Batch 7
NaN detected in reconstructed output at Epoch 8, Batch 8
NaN detected in mean at Epoch 8, Batch 8
NaN detected in logvar at Epoch 8, Batch 8
NaN detected in loss at Epoch 8, Batch 8
Epoch 8, Average Loss: nan
NaN detected in reconstructed output at Epoch 9, Batch 1
NaN detected in mean at Epoch 9, Batch 1
NaN detected in logvar at Epoch 9, Batch 1
NaN detected in loss at Epoch 9, Batch 1
NaN detected in reconstructed output at Epoch 9, Batch 2
NaN detected in mean at Epoch 9, Batch 2
NaN detected in logvar at Epoch 9, Batch 2
NaN detected in loss at Epoch 9, Batch 2
NaN detected in reconstructed output at Epoch 9, Batch 3
NaN de

 90%|█████████ | 9/10 [00:06<00:00,  1.65it/s]

NaN detected in reconstructed output at Epoch 9, Batch 7
NaN detected in mean at Epoch 9, Batch 7
NaN detected in logvar at Epoch 9, Batch 7
NaN detected in loss at Epoch 9, Batch 7
NaN detected in reconstructed output at Epoch 9, Batch 8
NaN detected in mean at Epoch 9, Batch 8
NaN detected in logvar at Epoch 9, Batch 8
NaN detected in loss at Epoch 9, Batch 8
Epoch 9, Average Loss: nan
NaN detected in reconstructed output at Epoch 10, Batch 1
NaN detected in mean at Epoch 10, Batch 1
NaN detected in logvar at Epoch 10, Batch 1
NaN detected in loss at Epoch 10, Batch 1
NaN detected in reconstructed output at Epoch 10, Batch 2
NaN detected in mean at Epoch 10, Batch 2
NaN detected in logvar at Epoch 10, Batch 2
NaN detected in loss at Epoch 10, Batch 2
NaN detected in reconstructed output at Epoch 10, Batch 3
NaN detected in mean at Epoch 10, Batch 3
NaN detected in logvar at Epoch 10, Batch 3
NaN detected in loss at Epoch 10, Batch 3
NaN detected in reconstructed output at Epoch 10, B

100%|██████████| 10/10 [00:06<00:00,  1.49it/s]

NaN detected in reconstructed output at Epoch 10, Batch 8
NaN detected in mean at Epoch 10, Batch 8
NaN detected in logvar at Epoch 10, Batch 8
NaN detected in loss at Epoch 10, Batch 8
Epoch 10, Average Loss: nan



