# Convolutional Autoencoder
This notebook uses a **convolutional autoencoder** to reduce the dimensionality of our dataset in a non-linear way. We then apply **k-means clustering**, as in the previous notebook, in the newly created lower-dimensional **latent space**. The goal is to discard less important variables and thereby obtain a better clustering.

In [1]:
# !pip install cartopy xarray matplotlib netCDF4 torch torchinfo

In [2]:
import helper_functions
import importlib
from ipywidgets import FloatSlider
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from ipywidgets import interact, IntSlider
from sklearn.decomposition import PCA
import numpy as np
import pandas as pd
import matplotlib.dates as mdates

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from torch.utils.data import DataLoader, TensorDataset
from torchinfo import summary
from torch.utils.data import DataLoader, random_split
importlib.reload(helper_functions)

<module 'helper_functions' from '/home/jovyan/spatiotemporal-mining-medsea/notebooks/helper_functions.py'>

## Data Loading & Preprocessing

In [3]:
# Open the medsea NetCDF file with xarray.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
ds = xr.open_dataset("/home/jovyan/spatiotemporal-mining-medsea/data/medsea.nc")

In [None]:
# Let the project helper build the model input array and its companion mask
# array (M_np is used further down as the mask of the reconstruction loss).
X_np, M_np = helper_functions.preprocessing_conv(
    ds,
    ["thetao", "so"],
    [50, 300, 1000],
    3,
)

# Tensor copies of both arrays; X is laid out as (B, C, H, W).
X, M = (torch.tensor(array) for array in (X_np, M_np))
print(X.shape)

## Splitting

In [5]:
from torch.utils.data import Dataset

class MaskedDataset(Dataset):
    """Dataset yielding ``(sample, mask)`` pairs as float32 tensors.

    Parameters
    ----------
    X : array-like or torch.Tensor
        Samples, indexed along the first axis.
    M : array-like or torch.Tensor
        Masks, indexed along the first axis; must hold as many entries as ``X``.

    Raises
    ------
    ValueError
        If ``X`` and ``M`` differ in length along the first axis.
    """

    def __init__(self, X, M):
        self.X = self._to_float_tensor(X)
        self.M = self._to_float_tensor(M)
        # Fail early: a length mismatch would otherwise only surface as an
        # index error in the middle of an epoch.
        if self.X.shape[0] != self.M.shape[0]:
            raise ValueError(
                f"X and M must have the same number of samples, "
                f"got {self.X.shape[0]} and {self.M.shape[0]}"
            )

    @staticmethod
    def _to_float_tensor(data):
        """Return an independent float32 tensor copy of ``data``."""
        if isinstance(data, torch.Tensor):
            # torch.tensor(tensor) warns about copy-construction; detach+clone
            # gives the same independent copy without the warning.
            return data.detach().clone().to(torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        """Number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the ``(sample, mask)`` pair stored at ``idx``."""
        return self.X[idx], self.M[idx]

In [6]:
# Wrap the arrays so that every item is a (sample, mask) pair.
full_dataset = MaskedDataset(X_np, M_np)

# 80 % of the samples for training, the remainder for validation.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# A dedicated, seeded generator makes the random split identical on every
# "Restart & Run All"; without it the train/validation membership (and hence
# every reported loss) changes from run to run.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
# NOTE: despite its name, test_loader serves the validation subset.
test_loader = DataLoader(val_set, batch_size=16)

## The Architecture

In [7]:
import torch
import torch.nn as nn

class CAE(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder stacks one stride-2 ``Conv2d`` block (convolution, instance
    norm, LeakyReLU, spatial dropout) per entry of ``channels``; the feature
    map is flattened and squeezed through ``Linear`` layers down to
    ``latent_dim``. The decoder mirrors this with ``Linear`` layers and
    stride-2 ``ConvTranspose2d`` blocks. With the default arguments the layer
    layout (and therefore the ``state_dict`` keys) is
    ``in_channels -> 32 -> 64 -> 128`` on a 203x514 grid, i.e. a 128x26x65
    bottleneck feature map.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input (and of the reconstruction).
    latent_dim : int, default 3
        Size of the latent vector.
    dropout_p : float, default 0.2
        Dropout probability used in every dropout layer.
    channels : sequence of int, default (32, 64, 128)
        Output channels of the successive encoder convolutions; the decoder
        uses them in reverse order.
    input_size : tuple of int, default (203, 514)
        ``(height, width)`` of the inputs; determines the size of the
        flattened feature map feeding the fully connected layers.

    Raises
    ------
    ValueError
        If ``channels`` is empty.
    """

    def __init__(self, in_channels, latent_dim=3, dropout_p=0.2,
                 channels=(32, 64, 128), input_size=(203, 514)):
        super().__init__()

        # Copy into a fresh list: callers may pass a list or a tuple, and the
        # immutable default avoids sharing state between instances.
        channels = list(channels)
        if not channels:
            raise ValueError("channels must contain at least one entry")
        self.input_size = tuple(input_size)

        # Encoder: every block halves the spatial size (rounding up).
        encoder_layers = []
        prev_channels = in_channels
        for width in channels:
            encoder_layers += [
                nn.Conv2d(prev_channels, width, 3, stride=2, padding=1),
                nn.InstanceNorm2d(width),
                nn.LeakyReLU(),
                nn.Dropout2d(p=dropout_p),  # spatial dropout
            ]
            prev_channels = width
        self.encoder = nn.Sequential(*encoder_layers)

        # Shape of the encoder output, needed to size the dense layers and to
        # fold the dense decoder output back into a feature map.
        enc_h, enc_w = self._encoded_hw(self.input_size, len(channels))
        self.unflatten_shape = (channels[-1], enc_h, enc_w)
        flat_dim = channels[-1] * enc_h * enc_w

        self.flatten = nn.Flatten()
        self.fc_enc = nn.Sequential(
            nn.Linear(flat_dim, 256),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(256, 64),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(64, latent_dim),
        )
        self.fc_dec = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(64, 256),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(256, flat_dim),
        )

        # Decoder: mirror of the encoder; each transposed convolution doubles
        # the spatial size. The last layer has no norm/activation/dropout.
        decoder_layers = []
        reversed_channels = channels[::-1]
        for src, dst in zip(reversed_channels[:-1], reversed_channels[1:]):
            decoder_layers += [
                nn.ConvTranspose2d(src, dst, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(dst),
                nn.LeakyReLU(),
                nn.Dropout2d(p=dropout_p),
            ]
        decoder_layers.append(
            nn.ConvTranspose2d(reversed_channels[-1], in_channels, 3,
                               stride=2, padding=1, output_padding=1)
        )
        self.decoder = nn.Sequential(*decoder_layers)

    @staticmethod
    def _encoded_hw(input_size, num_layers):
        """Spatial size left after ``num_layers`` convolutions (k=3, s=2, p=1)."""
        height, width = input_size
        for _ in range(num_layers):
            # Conv2d output size: floor((n + 2*1 - 3) / 2) + 1
            height = (height - 1) // 2 + 1
            width = (width - 1) // 2 + 1
        return height, width

    def forward(self, x):
        """Encode ``x`` to the latent vector and decode it back.

        Returns the reconstruction cropped to the spatial size of ``x``: the
        decoder output is a multiple of ``2 ** len(channels)`` per axis and
        therefore at least as large as the input.
        """
        height, width = x.shape[-2:]
        x = self.encoder(x)
        z = self.fc_enc(self.flatten(x))
        x = self.fc_dec(z)
        x = x.view(x.size(0), *self.unflatten_shape)
        x = self.decoder(x)
        return x[:, :, :height, :width]


In [8]:
# import torch
# import torch.nn as nn

# class CAE(nn.Module):
#     def __init__(self, in_channels, latent_dim=3, dropout_p=0.1, channels=[32, 64, 128]):
#         super().__init__()

#         # Build the encoder dynamically
#         encoder_layers = []
#         prev_channels = in_channels
#         for ch in channels:
#             encoder_layers.extend([
#                 nn.Conv2d(prev_channels, ch, 3, stride=2, padding=1),
#                 nn.InstanceNorm2d(ch),
#                 nn.LeakyReLU(),
#                 nn.Dropout2d(p=dropout_p)
#             ])
#             prev_channels = ch
#         self.encoder = nn.Sequential(*encoder_layers)

#         # Final shape after encoding (you have to set these values yourself to match the input size!)
#         # Example: three stride-2 convolutions on a 203x514 input → ceil(203/8) = 26, ceil(514/8) = 65
#         self.unflatten_shape = (channels[-1], 26, 65)
#         flatten_dim = channels[-1] * self.unflatten_shape[1] * self.unflatten_shape[2]

#         self.flatten = nn.Flatten()
#         self.fc_enc = nn.Sequential(
#             nn.Linear(flatten_dim, 256),
#             nn.LeakyReLU(),
#             nn.Dropout(p=dropout_p),
#             nn.Linear(256, 64),
#             nn.LeakyReLU(),
#             nn.Dropout(p=dropout_p),
#             nn.Linear(64, latent_dim),
#         )
#         self.fc_dec = nn.Sequential(
#             nn.Linear(latent_dim, 64),
#             nn.LeakyReLU(),
#             nn.Dropout(p=dropout_p),
#             nn.Linear(64, 256),
#             nn.LeakyReLU(),
#             nn.Dropout(p=dropout_p),
#             nn.Linear(256, flatten_dim),
#         )

#         # Build the decoder dynamically (reverse order)
#         decoder_layers = []
#         rev_channels = list(reversed(channels))
#         for i in range(len(rev_channels) - 1):
#             decoder_layers.extend([
#                 nn.ConvTranspose2d(rev_channels[i], rev_channels[i+1], 3, stride=2, padding=1, output_padding=1),
#                 nn.InstanceNorm2d(rev_channels[i+1]),
#                 nn.LeakyReLU(),
#                 nn.Dropout2d(p=dropout_p)
#             ])
#         decoder_layers.append(
#             nn.ConvTranspose2d(rev_channels[-1], in_channels, 3, stride=2, padding=1, output_padding=1)
#         )
#         self.decoder = nn.Sequential(*decoder_layers)

#     def forward(self, x):
#         x = self.encoder(x)
#         z = self.fc_enc(self.flatten(x))
#         x = self.fc_dec(z)
#         x = x.view(x.size(0), *self.unflatten_shape)
#         x = self.decoder(x)
#         return x[:, :, :203, :514]  # optional crop


In [9]:
# Per-epoch loss histories; train() appends one value to each list per epoch.
train_losses, val_losses = [], []

## Loss Function

In [10]:
def masked_mse(x_recon, x_true, mask):
    """Mean squared error restricted to the positions selected by ``mask``.

    Parameters
    ----------
    x_recon : torch.Tensor
        Reconstruction.
    x_true : torch.Tensor
        Target; must broadcast against ``x_recon``.
    mask : torch.Tensor
        Per-element weights (0 excludes an element); must broadcast against
        the squared error.

    Returns
    -------
    torch.Tensor
        Scalar: sum of the masked squared errors divided by ``mask.sum()``.
        When the mask sums to zero the result is 0 instead of NaN.
    """
    squared_error = ((x_recon - x_true) ** 2) * mask
    weight = mask.sum()
    # An all-zero mask would give 0/0 = NaN and poison both the running loss
    # and the gradients; dividing by one instead yields a clean zero. Any
    # non-zero weight is used unchanged.
    safe_weight = torch.where(weight != 0, weight, torch.ones_like(weight))
    return squared_error.sum() / safe_weight

## Training Loop

In [11]:
def train(num_epochs: int):
    """Run ``num_epochs`` epochs of training, validating after each one.

    Relies on the notebook-level ``model``, ``optimizer``, ``device``,
    ``train_loader`` and ``test_loader``. One sample-weighted mean loss per
    epoch is appended to the notebook-level ``train_losses`` and
    ``val_losses`` lists, so repeated calls keep extending the same histories.

    Parameters
    ----------
    num_epochs : int
        Number of passes over ``train_loader``.

    Returns
    -------
    tuple of list
        ``(train_losses, val_losses)`` — the notebook-level lists themselves.
    """
    torch.cuda.empty_cache()

    def _batch_loss(inputs, valid):
        # Move one batch to the compute device and score its reconstruction;
        # also report the batch size for sample-weighted averaging.
        inputs = inputs.to(device)
        valid = valid.to(device)
        return masked_mse(model(inputs), inputs, valid), inputs.size(0)

    for epoch in range(num_epochs):
        # ---- optimisation pass ----
        model.train()
        train_total = 0.0
        for inputs, valid in train_loader:
            optimizer.zero_grad()
            batch_loss, batch_count = _batch_loss(inputs, valid)
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller last batch does not skew the mean.
            train_total += batch_loss.item() * batch_count

        train_loss = train_total / len(train_loader.dataset)
        train_losses.append(train_loss)

        # ---- validation pass: eval mode, no gradient tracking ----
        model.eval()
        val_total = 0.0
        with torch.no_grad():
            for inputs, valid in test_loader:
                batch_loss, batch_count = _batch_loss(inputs, valid)
                val_total += batch_loss.item() * batch_count

        val_loss = val_total / len(test_loader.dataset)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    return train_losses, val_losses


## Setup

In [12]:
# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CUDA housekeeping is only meaningful with a GPU; calling
# torch.cuda.reset_peak_memory_stats() on a machine without CUDA raises, which
# would defeat the CPU fallback above.
if device.type == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

model = CAE(in_channels=X_np.shape[1], latent_dim=3, dropout_p=0.1, channels=[32, 64, 128]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# `summary` comes from the torchinfo import in the notebook's import cell.
# Kept as the last expression so the layer table is displayed as cell output.
summary(model, input_size=(1, X_np.shape[1], 203, 514))

Layer (type:depth-idx)                   Output Shape              Param #
CAE                                      [1, 6, 203, 514]          --
├─Sequential: 1-1                        [1, 128, 26, 65]          --
│    └─Conv2d: 2-1                       [1, 32, 102, 257]         1,760
│    └─InstanceNorm2d: 2-2               [1, 32, 102, 257]         --
│    └─LeakyReLU: 2-3                    [1, 32, 102, 257]         --
│    └─Dropout2d: 2-4                    [1, 32, 102, 257]         --
│    └─Conv2d: 2-5                       [1, 64, 51, 129]          18,496
│    └─InstanceNorm2d: 2-6               [1, 64, 51, 129]          --
│    └─LeakyReLU: 2-7                    [1, 64, 51, 129]          --
│    └─Dropout2d: 2-8                    [1, 64, 51, 129]          --
│    └─Conv2d: 2-9                       [1, 128, 26, 65]          73,856
│    └─InstanceNorm2d: 2-10              [1, 128, 26, 65]          --
│    └─LeakyReLU: 2-11                   [1, 128, 26, 65]          --
│   

In [None]:
# Train for 100 epochs. train() returns the notebook-level loss histories,
# which are re-bound here to the same names.
num_epochs = 100
train_losses, val_losses = train(num_epochs)

Epoch 1/100 | Train Loss: 1.0240 | Val Loss: 0.9954
Epoch 2/100 | Train Loss: 1.0031 | Val Loss: 0.9942


## Evaluation

In [None]:
# Zoom into the epochs after `zoom_from_epoch` when the history is long enough;
# otherwise show the whole history. A fixed slice [190:] is empty after a
# 100-epoch run and would produce a blank figure.
zoom_from_epoch = 190
first_shown = zoom_from_epoch if len(train_losses) > zoom_from_epoch else 0

plt.figure(figsize=(8, 4))
# Explicit x values so the axis shows real (1-based) epoch numbers.
plt.plot(range(first_shown + 1, len(train_losses) + 1), train_losses[first_shown:],
         label="Train Loss", marker='o')
plt.plot(range(first_shown + 1, len(val_losses) + 1), val_losses[first_shown:],
         label="Validation Loss", marker='x')
plt.xlabel("Epoch")
plt.ylabel("Loss (MSE)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


## Reconstructing

In [None]:
# Pass the full input tensor through the model in batches of 16
# (presumably returns the stacked reconstructions — see helper_functions).
X_recon_all = helper_functions.reconstruct_in_batches(X, model, device, batch_size=16)

# Flat boolean vector (NumPy) that is True where the mask of sample 0 is positive.
# NOTE(review): only the first sample's mask is used — presumably the mask is
# identical for every sample; confirm against preprocessing_conv.
M_tensor = torch.tensor(M_np[:, :, :203, :514], dtype=torch.float32)
valid_flat_mask = (M_tensor[0] > 0).reshape(-1).cpu().numpy()

# X_recon_all is re-bound here: from this line on it holds the helper's
# masked-position vector form, not the raw reconstruction from above.
X_recon_all = helper_functions.reconstruction_to_vector_masked_positions(X_recon_all, valid_flat_mask)

In [None]:
# Schritt 1 – Originale z-Werte stacken
original_z = helper_functions.preprocessing(ds, ["thetao"], [50], "location")

# Schritt 2 – Validitätsmaske vorbereiten
M_mask = M_tensor[0].reshape(-1) > 0  # (C * H * W,)
location_coords = original_z.location.values
time_coords = original_z.time.values

# Schritt 3 – CAE-Rekonstruktion erzeugen
X_recon = helper_functions.reconstruct_in_batches(X, model, device, batch_size=16)

# Schritt 4 – Nur gültige Positionen extrahieren
flat_recon = helper_functions.reconstruction_to_vector_masked_positions(X_recon, M_mask)

# Schritt 5 – Reconstructed z_stack erzeugen
# reconstructed_z = build_z_stack(flat_recon, location_coords, time_coords)
reconstructed_z = helper_functions.reconstructed_to_stack(ds, "thetao", 50, flat_recon)

# Schritt 6 – Vergleich anzeigen
helper_functions.plot_reconstruction_comparison(
    z_stack_original=original_z,
    z_stack_recon=reconstructed_z,
    time_indices=[0, 15, 31, 190],
    cmin=-2,
    cmax=2
)

## Clustering with K-Means

### Clustering the reconstructions

### Latent = 3

In [None]:
# Re-import helper_functions so that edits to the module take effect without
# restarting the kernel.
importlib.reload(helper_functions)

In [None]:
# Cluster the reconstruction vectors with the project's k-means helper, k = 9.
k=9
labels = helper_functions.apply_kmeans(X_recon_all, k)
# Shift every label by one in place — presumably so cluster ids start at 1
# instead of 0; confirm against apply_kmeans.
labels += 1

### Depth = 50

In [None]:
# Reconstructed "thetao" at depth 50, converted by the helper into the stacked
# form expected by the plotting functions.
recon_temp_50 = helper_functions.reconstructed_to_stack(ds, "thetao", 50, X_recon_all)

# Cluster membership over time, then the per-cluster average on a -2..2 colour range.
helper_functions.plot_cluster_timeline(recon_temp_50, labels)
helper_functions.plot_average_cluster(recon_temp_50, labels, -2, 2)

In [None]:
# Reconstructed "so" at depth 50 and its per-cluster average (-2..2 colour range).
recon_so_50 = helper_functions.reconstructed_to_stack(ds, "so", 50, X_recon_all)
helper_functions.plot_average_cluster(recon_so_50, labels, -2, 2)

### Depth = 300

In [None]:
# Reconstructed "thetao" at depth 300 and its per-cluster average (-2..2 colour range).
recon_temp_300 = helper_functions.reconstructed_to_stack(ds, "thetao", 300, X_recon_all)
helper_functions.plot_average_cluster(recon_temp_300, labels, -2, 2)

In [None]:
# Reconstructed "so" at depth 300 and its per-cluster average (-2..2 colour range).
recon_so_300 = helper_functions.reconstructed_to_stack(ds, "so", 300, X_recon_all)
helper_functions.plot_average_cluster(recon_so_300, labels, -2, 2)

### Depth = 1000

In [None]:
# Reconstructed "thetao" at depth 1000 and its per-cluster average.
# Both the requested depth and the plotted variable must be the 1000 m ones
# for this section; a copy-paste slip previously extracted depth 300 and
# plotted the depth-50 stack again.
recon_temp_1000 = helper_functions.reconstructed_to_stack(ds, "thetao", 1000, X_recon_all)
helper_functions.plot_average_cluster(recon_temp_1000, labels, -2, 2)

In [None]:
# Reconstructed "so" at depth 1000 and its per-cluster average (-2..2 colour range).
recon_so_1000 = helper_functions.reconstructed_to_stack(ds, "so", 1000, X_recon_all)
helper_functions.plot_average_cluster(recon_so_1000, labels, -2, 2)