# Variational Convolutional Autoencoder

This notebook uses a **Convolutional Variational Autoencoder (CVAE)**.  
It combines the strengths of a **Convolutional Autoencoder**, which captures local spatial patterns in maps, with the **Variational Autoencoder** approach that maps each input to a smooth probability distribution in latent space.  
This way, we compress oceanographic maps into robust, compact representations that stay comparable and regularized.  
Masking handles missing or irrelevant regions, and the structured latent space makes it easier to cluster or analyze similar patterns.



In [3]:
!pip install -q -r ../../requirements.txt &> /dev/null

## Imports

In [13]:
import xarray as xr
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from ipywidgets import interact, IntSlider

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.decomposition import PCA

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split, Dataset
from torchinfo import summary
import random

import sys
sys.path.append('../..')
import helper_functions
import importlib
importlib.reload(helper_functions)

import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [14]:
SEED = 27
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

trend_removal=False

Since our ConvNet works with masked map data, it naturally produces some NaNs — that's expected behavior. To avoid flooding the output with warnings, we've disabled them here.

## Data Loading & Preprocessing

In [15]:
class MaskedDataset(Dataset):
    def __init__(self, X, M):
        self.X = X
        self.M = M

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.M[idx]

### Training Data

In [16]:
ds_train = xr.open_dataset("../../data/medsea1987to2025_train.nc")

X_np, M_np = helper_functions.preprocessing_conv(ds_train, ["thetao", "so"], [50, 300, 1000], trend_removal, 1)
X_train = torch.tensor(X_np)
M_train = torch.tensor(M_np)
train_dataset = MaskedDataset(X_train, M_train)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED)
)

print(f"Train set size: {len(train_dataset)}")

Train set size: 461


### Validation Data

In [17]:
ds_val = xr.open_dataset("../../data/medsea1987to2025_val.nc")

X_np, M_np = helper_functions.preprocessing_conv(ds_train, ["thetao", "so"], [50, 300, 1000], trend_removal, 1)
X_val = torch.tensor(X_np)
M_val = torch.tensor(M_np)
val_dataset = MaskedDataset(X_val, M_val)

val_size = int(0.4 * len(val_dataset))
_ , val_subset = torch.utils.data.random_split(
    val_dataset,
    [len(val_dataset) - val_size, val_size],
    generator=torch.Generator().manual_seed(SEED)
)

val_loader = DataLoader(
    val_subset,
    batch_size=32,
    shuffle=False
)

print(f"Validation set size: {len(val_subset)}")

Validation set size: 184


## The Architecture

Again, we include all depths and features for reconstruction. The difference in our ConvNet is that these features are not simply concatenated, but stacked as channels — similar to how RGB channels work in images. This allows the network to capture spatial relationships between features and depths more effectively, which might be important in our case.

In [23]:
class VCAE(nn.Module):
    def __init__(self, in_channels, latent_dim=3, dropout_p=0.2, channels=[32, 64, 128, 256, 512, 1024], input_shape=(190, 508)):
        super().__init__()

        self.input_shape = input_shape
        self.channels = channels

        encoder_layers = []
        prev_channels = in_channels
        h, w = input_shape

        for ch in channels:
            encoder_layers += [
                nn.Conv2d(prev_channels, ch, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(ch),
                nn.LeakyReLU(),
                nn.Dropout2d(p=dropout_p),
            ]
            prev_channels = ch
            h = math.floor((h + 2 * 1 - 3) / 2 + 1)
            w = math.floor((w + 2 * 1 - 3) / 2 + 1)

        self.encoder = nn.Sequential(*encoder_layers)
        self.unflatten_shape = (channels[-1], h, w)

        flat_dim = channels[-1] * h * w
        self.flatten = nn.Flatten()
        self.fc_shared = nn.Sequential(
            nn.Linear(flat_dim, 1024),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(1024, 64),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_p),
        )
        self.fc_mean = nn.Linear(64, latent_dim)
        self.fc_logvar = nn.Linear(64, latent_dim)

        self.fc_dec = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(64, 1024),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(1024, flat_dim)
        )

        decoder_layers = []
        rev_channels = list(reversed(channels))
        for i in range(len(rev_channels) - 1):
            decoder_layers += [
                nn.ConvTranspose2d(rev_channels[i], rev_channels[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(rev_channels[i + 1]),
                nn.LeakyReLU(),
                nn.Dropout2d(p=dropout_p),
            ]

        decoder_layers += [
            nn.ConvTranspose2d(rev_channels[-1], in_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x, mask=None):
        if mask is not None:
            x = x * mask
        x = self.encoder(x)
        h = self.fc_shared(self.flatten(x))
        z_mean = self.fc_mean(h)
        z_logvar = self.fc_logvar(h)
        z = self.reparameterize(z_mean, z_logvar)
        x = self.fc_dec(z)
        x = x.view(x.size(0), *self.unflatten_shape)
        x = self.decoder(x)
        return x[:, :, :self.input_shape[0], :self.input_shape[1]], z_mean, z_logvar


### Convolutional Autoencoder

We use a simple convolutional autoencoder that compresses the input maps into a small latent vector and reconstructs them back. The encoder reduces spatial resolution through convolutional layers, while the decoder upsamples the data back using transposed convolutions.

I experimented with different numbers of layers, kernel sizes, and dropout values. The current setup gave the best trade-off between reconstruction quality and training stability.


### Hard Masking
- Since many values in the input maps are invalid (e.g. land areas), we apply a **hard mask** before feeding the data into the encoder.
- This mask zeroes out irrelevant values, so the model only learns from valid oceanic regions.
- During training, the same mask is applied to the loss function to ensure the model is not penalized for errors in masked-out areas.
- This is essential for learning robust spatial patterns without being misled by missing or irrelevant data.


## Loss Function

In [19]:
def masked_recon_loss(x_recon, x_true, mask):
    loss = ((x_recon - x_true) ** 2) * mask
    return loss.sum() / mask.sum()

## Training Loop

In [26]:
def train(num_epochs: int, kl_annealing_epochs: int = 50, bint: float = 0.1):
    torch.cuda.empty_cache()

    train_losses = []
    val_losses = []
    train_recons = []
    val_recons = []
    train_kls = []
    val_kls = []

    for epoch in range(num_epochs):
        beta = min(bint, epoch / kl_annealing_epochs * bint)
        model.train()
        running_train_recon = 0.0
        running_train_kl = 0.0

        for x, mask in train_loader:
            x = x.to(device)
            mask = mask.to(device)
            optimizer.zero_grad()
            x_recon, z_mean, z_logvar = model(x, mask=mask)
            kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp(), dim=1).mean()
            recon_loss = masked_recon_loss(x_recon, x, mask)
            loss = recon_loss + beta * kl_loss
            loss.backward()
            optimizer.step()
            running_train_recon += recon_loss.item() * x.size(0)
            running_train_kl += kl_loss.item() * x.size(0)

        train_recon = running_train_recon / len(train_loader.dataset)
        train_kl = running_train_kl / len(train_loader.dataset)
        train_loss = train_recon + beta * train_kl

        train_losses.append(train_loss)
        train_recons.append(train_recon)
        train_kls.append(train_kl)

        model.eval()
        running_val_recon = 0.0
        running_val_kl = 0.0

        with torch.no_grad():
            for x, mask in val_loader:
                x = x.to(device)
                mask = mask.to(device)
                x_recon, z_mean, z_logvar = model(x, mask=mask)
                kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp(), dim=1).mean()
                recon_loss = masked_recon_loss(x_recon, x, mask)
                running_val_recon += recon_loss.item() * x.size(0)
                running_val_kl += kl_loss.item() * x.size(0)

        val_recon = running_val_recon / len(val_loader.dataset)
        val_kl = running_val_kl / len(val_loader.dataset)
        val_loss = val_recon + beta * val_kl

        val_losses.append(val_loss)
        val_recons.append(val_recon)
        val_kls.append(val_kl)

        if (epoch+1) % 10 == 0:
            print(
                f"Epoch {epoch+1}/{num_epochs} | "
                f"β: {beta:.5f} | "
                f"Train Loss: {train_loss:.4f} (Recon: {train_recon:.4f}, KL: {train_kl:.4f}) | "
                f"Val Loss: {val_loss:.4f} (Recon: {val_recon:.4f}, KL: {val_kl:.4f})"
            )

    return train_losses, val_losses, train_recons, val_recons, train_kls, val_kls


## Setup

In [21]:
train_losses = []
val_losses = []

In [24]:
torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VCAE(in_channels=X_np.shape[1], latent_dim=3, dropout_p=0.2, channels= [32, 64, 128, 256, 512, 1024]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0004)

summary(model, input_size=(1, X_np.shape[1], 190, 508))

Layer (type:depth-idx)                   Output Shape              Param #
VCAE                                     [1, 6, 190, 508]          --
├─Sequential: 1-1                        [1, 1024, 3, 8]           --
│    └─Conv2d: 2-1                       [1, 32, 95, 254]          1,760
│    └─InstanceNorm2d: 2-2               [1, 32, 95, 254]          --
│    └─LeakyReLU: 2-3                    [1, 32, 95, 254]          --
│    └─Dropout2d: 2-4                    [1, 32, 95, 254]          --
│    └─Conv2d: 2-5                       [1, 64, 48, 127]          18,496
│    └─InstanceNorm2d: 2-6               [1, 64, 48, 127]          --
│    └─LeakyReLU: 2-7                    [1, 64, 48, 127]          --
│    └─Dropout2d: 2-8                    [1, 64, 48, 127]          --
│    └─Conv2d: 2-9                       [1, 128, 24, 64]          73,856
│    └─InstanceNorm2d: 2-10              [1, 128, 24, 64]          --
│    └─LeakyReLU: 2-11                   [1, 128, 24, 64]          --
│   

In [27]:
train_losses, val_losses, train_recons, val_recons, train_kls, val_kls = train(1000, 600, 0.04)

Epoch 10/1000 | β: 0.00060 | Train Loss: 0.7305 (Recon: 0.7226, KL: 13.1507) | Val Loss: 0.6687 (Recon: 0.6627, KL: 10.1468)
Epoch 20/1000 | β: 0.00127 | Train Loss: 0.6544 (Recon: 0.6414, KL: 10.2626) | Val Loss: 0.6066 (Recon: 0.5937, KL: 10.1607)
Epoch 30/1000 | β: 0.00193 | Train Loss: 0.6341 (Recon: 0.6187, KL: 7.9901) | Val Loss: 0.5857 (Recon: 0.5742, KL: 5.9149)
Epoch 40/1000 | β: 0.00260 | Train Loss: 0.5997 (Recon: 0.5821, KL: 6.7816) | Val Loss: 0.5414 (Recon: 0.5242, KL: 6.5832)


KeyboardInterrupt: 

## Evaluation

In [None]:
# train_losses2 = [min(x,1) for x in train_losses]
# val_losses2 = [min(x,2) for x in val_losses]
helper_functions.plot_metrics([(train_losses2, "Train Loss"), (val_losses2, "Validation Loss")], "Loss")
helper_functions.plot_metrics([(train_recons, "Train MSE-Scores"), (val_recons, "Validation MSE-Scores")], "MSE-Scores")
# train_kls2 = [min(x,7) for x in train_kls]
# val_kls2 = [min(x,7) for x in val_kls]
helper_functions.plot_metrics([(train_kls2, "Train KL-Scores"), (val_kls2, "Validation KL-Scores")], "KL-Scores")

In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

latents_mu = []
latents_logvar = []

with torch.no_grad():
    for batch in train_loader:
        x = batch[0].to(device).float()
        _, mu, logvar = model(x)
        latents_mu.append(mu.cpu())
        latents_logvar.append(logvar.cpu())

mu_all = torch.cat(latents_mu, dim=0)           # shape: (n_samples, latent_dim)
logvar_all = torch.cat(latents_logvar, dim=0)   # shape: (n_samples, latent_dim)

# Statistics
mu_std = mu_all.std(dim=0)
logvar_mean = logvar_all.mean(dim=0)
logvar_std = logvar_all.std(dim=0)

print("Std of mu per latent dim:")
print(mu_std)

print("\nMean of logvar per latent dim:")
print(logvar_mean)

print("\nStd of logvar per latent dim:")
print(logvar_std)

Std of mu per latent dim:
tensor([1.0984, 1.4006, 1.3363])

Mean of logvar per latent dim:
tensor([-3.4504, -3.8279, -2.5183])

Std of logvar per latent dim:
tensor([1.1129, 1.1008, 0.4396])


In [29]:
torch.save(model.state_dict(), "VCAE.pth")