# DeepClusterV2

In the first notebook, you got a first look at xarray and the data set we are going to use.

The ERA5 data originally have a temporal resolution of 1 hour on a $30\,\mathrm{km}$ grid ($0.25^\circ$). Our data set contains data on multiple pressure levels: $300, 500, 700, 850$, and $950\,\mathrm{hPa}$. Although the model provides a large variety of physical output quantities, we only make use of five here:

1. Geopotential height $z$ $[10^{-2}\,]$
1. Temperature $t$ $[\mathrm{K}]$
1. Relative humidity $r$ $[\%]$
1. Zonal (eastward) wind speed $u$ $[\mathrm{m}/\mathrm{s}]$
1. Meridional (northward) wind speed $v$ $[\mathrm{m}/\mathrm{s}]$

The coordinates used are:

1. Time
1. Pressure Level
1. Latitude
1. Longitude

The order of these follows the [CF 1.6 conventions for NetCDF data](https://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#dimensions), which is also given in the dataset's metadata.

As you can see from the plots you made in the previous notebook, the data cover the whole of Europe. As stated above, the data are given at a resolution of $\sim 30\,\mathrm{km}$.
Plotting the temperature field lets you clearly distinguish the continental and oceanic areas.

## The problem we're trying to solve

The goal of this application is to identify recurring large-scale weather patterns (or regimes, _LSWRs_) over Europe and eventually investigate whether their occurrence has an effect on the power production forecast quality of ML models.

We want to find these patterns on a daily basis, meaning we assume that each pattern roughly occurs and lasts for at least one day.

So first, we aim to develop an unsupervised clustering algorithm that is able to identify patterns in high-dimensional data: we want to make use of multiple physical quantities on multiple pressure levels over the whole of Europe.

One proven algorithm that allows unsupervised clustering of images is _DeepClusterV2_ (_DCv2_, see [the DCv2 paper](https://arxiv.org/abs/2006.09882v5)), which achieves high accuracy on typical tasks in image recognition. The procedure of DCv2 is as follows:

1. Applies different data augmentation strategies (random cropping, rotation, mirroring) to the images.
2. Feeds the images to a CNN (ResNet50) followed by an MLP, which has a certain output dimensionality.
3. Clusters the images in the low-dimensional feature space that is the output of the CNN+MLP.
4. Uses the cluster assignments to run backpropagation on the CNN+MLP to adjust the weights such that similar samples get closer and closer with each iteration.


<img src="./images/dcv2-architecture.png" width="100%" height="100%">

Since DCv2 originally only works with typical RGB images (i.e. 3 input channels), we need to adjust the ResNet to allow more than 3 input channels, depending on how many pressure levels and variables per level we want to use for training.

Each sample will represent one day in our time series of ERA5 data, and we will use multiple levels and variables.
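
To make the channel layout concrete, the following sketch stacks every (variable, level) pair of one day into a single multi-channel array, analogous to the colour channels of an image. The array sizes and the helper name `stack_to_channels` are illustrative assumptions and not part of the `a6` package.

```python
import numpy as np


def stack_to_channels(fields: np.ndarray) -> np.ndarray:
    """Flatten the (variable, level) axes of one sample into a channel axis.

    ``fields`` has shape (n_variables, n_levels, n_lat, n_lon); the result
    has shape (n_variables * n_levels, n_lat, n_lon).
    """
    n_variables, n_levels, n_lat, n_lon = fields.shape
    return fields.reshape(n_variables * n_levels, n_lat, n_lon)


# Hypothetical sample: 4 variables on 3 pressure levels on a 5x7 grid.
sample = np.random.default_rng(0).normal(size=(4, 3, 5, 7))
channels = stack_to_channels(sample)
print(channels.shape)  # (12, 5, 7)
```

With this layout, channel `v * n_levels + l` holds variable `v` on level `l`, so the number of input channels is simply the product of the two counts.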

**Question:** If we use 2 pressure levels, e.g. $500\,\mathrm{hPa}$ and $950\,\mathrm{hPa}$, and the variables $z$, $r$, and $t$, how many input channels would the ResNet require?

## Now Let's Code

We will now take a look at the required code. First, we will import all required modules and set up a logger that will allow us to print any important information from anywhere in our code to the output.

In [None]:
import pathlib
import sys
import time
import logging

import numpy as np
import scipy.constants as constants
import torch
import torch.nn as nn
import xarray as xr
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
import openTSNE

import a6
import a6.dcv2._logs as logs
import a6.dcv2._averaging as averaging
import a6.plotting._colors as _colors

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    stream=sys.stdout,
)
logger = logging.getLogger("notebook")
logger.info("Logger initialized")

## Data Loading

For model training, we will be using PyTorch. For image recognition, PyTorch provides an additional package called `torchvision` that offers many helpful methods and features
for image processing. One example is common transformations of images to achieve data augmentation. It also provides a set of so-called data loaders that load common data types for
different ML tasks.

For supervised learning, for example, it provides a data loader that reads images (e.g. in TIF format) together with the label of each image.
However, we have neither labels nor data in TIF format, since our samples are not images.

As a consequence, we have to write our own data loader. Luckily, torchvision data loaders have a very lightweight abstract interface that makes it possible to implement data loaders for any kind of data. For more details, see the [torchvision docs](https://pytorch.org/vision/main/datasets.html).
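
As a rough illustration of how small that interface is, the sketch below implements the map-style dataset protocol (`__len__` and `__getitem__`) in plain Python. The class `DailyFieldsDataset`, its parameters, and the toy grids are hypothetical; the actual `a6` data sets are more elaborate, and in real code the class would subclass `torch.utils.data.Dataset`.

```python
import random


class DailyFieldsDataset:
    """Minimal map-style dataset: ``__len__`` plus ``__getitem__``."""

    def __init__(self, samples, n_crops=2, crop_size=2, seed=0):
        self.samples = samples  # one 2-D grid (list of rows) per day
        self.n_crops = n_crops
        self.crop_size = crop_size
        self._rng = random.Random(seed)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        grid = self.samples[index]
        crops = [self._random_crop(grid) for _ in range(self.n_crops)]
        # Return the index as well so that crops can be traced back to a day.
        return index, crops

    def _random_crop(self, grid):
        size = self.crop_size
        top = self._rng.randint(0, len(grid) - size)
        left = self._rng.randint(0, len(grid[0]) - size)
        return [row[left:left + size] for row in grid[top:top + size]]


# Three hypothetical "days", each a 4x4 grid of values.
days = [
    [[d * 100 + 10 * i + j for j in range(4)] for i in range(4)]
    for d in range(3)
]
dataset = DailyFieldsDataset(days)
index, crops = dataset[1]
print(len(dataset), index, len(crops))  # 3 1 2
```

Returning the sample index together with the crops mirrors the `(index, inputs)` pairs that the training code in this notebook iterates over.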

Furthermore, we will do _some_ preprocessing of the data before feeding them to the ResNet.
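
One of these preprocessing steps is the conversion of the geopotential into a geopotential height. The following sketch shows the arithmetic for a single value. The function name and the input value are illustrative assumptions; the constant equals `scipy.constants.g`.

```python
STANDARD_GRAVITY = 9.80665  # m/s^2, same value as scipy.constants.g


def geopotential_to_height(geopotential: float, scaling: float = 10.0) -> float:
    """Convert geopotential (m^2/s^2) to geopotential height.

    Dividing by g gives meters; the extra ``scaling`` of 10 gives decameters.
    """
    return geopotential / STANDARD_GRAVITY / scaling


# Illustrative input value of 54000 m^2/s^2.
height_dam = geopotential_to_height(54000.0)
print(round(height_dam, 1))  # 550.6
```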

In [None]:
@a6.utils.make_functional
def calculate_geopotential_height(
    data: xr.Dataset,
    scaling: float = 10.0,
) -> xr.Dataset:
    """Calculate the geopotential height from the geopotential.

    Parameters
    ----------
    data : xr.Dataset
        Data containing the geopotential.
    scaling : float, default=10.0
        Divisor applied after dividing the geopotential by ``g``.
        The default of 10 converts the height from meters to decameters.

    """
    data["z_h"] = data["z"] / constants.g / scaling
    return data


@a6.utils.make_functional
def drop_variables_from_dataset(
    data: xr.Dataset,
    names: str | list[str],
) -> xr.Dataset:
    """Drop variables from dataset."""
    return data.drop_vars(names)


def create_dataset(
    path: pathlib.Path,
    is_netcdf: bool,
    nmb_crops: tuple[int, ...],
    size_crops: tuple[float, ...],
    min_scale_crops: tuple[float, ...],
    max_scale_crops: tuple[float, ...],
    drop_variables: list[str] | None = None,
) -> a6.datasets.crop.Base:
    if is_netcdf:
        logger.info("Reading data from netCDF files %s", path)
        ds = xr.open_dataset(
            path,
            engine="netcdf4",
            drop_variables=drop_variables or [],
        )
        postprocessing = (
            a6.features.methods.weighting.weight_by_latitudes(
                latitudes="latitude",
                use_sqrt=True,
            )
            >> calculate_geopotential_height()
            >> drop_variables_from_dataset(names=["z"])
        )
        logger.info("Applying postprocessing to dataset")
        ds = postprocessing(ds)
        return a6.datasets.crop.MultiCropXarrayDataset(
            data_path=path,
            dataset=ds,
            nmb_crops=nmb_crops,
            size_crops=size_crops,
            min_scale_crops=min_scale_crops,
            max_scale_crops=max_scale_crops,
            return_index=True,
        )
    logger.info("Assuming image folder dataset")
    return a6.datasets.crop.MultiCropDataset(
        data_path=path,
        nmb_crops=nmb_crops,
        size_crops=size_crops,
        min_scale_crops=min_scale_crops,
        max_scale_crops=max_scale_crops,
        return_index=True,
    )

Now we can load the data and define how many augmented crops we want to train our model with, and what the parameters for these crops are.

## The Clustering Algorithm

The following code contains the core part of the algorithm that performs the spherical $K$-Means on our data in feature space.

The `init_embeddings` method initially runs our data $X_i$ through the ResNet and MLP and stores the resulting feature space vectors $Z_i$ of each sample (in the code, these are called `embeddings`).
The `indexes` tensor stores the indexes of each sample in our original data set. An index of `10`, for example, reflects the 10th day from the beginning of our data set.

The `cluster_embeddings` method then uses the embeddings of the previous epoch (or the initial embeddings in case of the first epoch), and clusters them with a spherical $K$-Means.
To do so, it randomly selects $K$ samples from the data set as initial centroids and then assigns each sample to the centroid it is most similar to.
I.e., spherical $K$-Means assigns each sample the label of the cluster centroid whose $128$-D vector has the largest dot product (cosine similarity) with the normalized feature vector of the sample.

Here, we also store the sample indexes, embeddings, cluster assignments, and the indexes of the centroid samples in torch tensors so that we can inspect and plot them later.
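
Before looking at the PyTorch implementation, it may help to see the same idea in a few lines of NumPy. This is a minimal sketch of spherical $K$-Means under the assumption that all input vectors are L2-normalized. The function name, the toy data, and the fixed seeds are illustrative and not taken from the `a6` code.

```python
import numpy as np


def spherical_kmeans(embeddings, n_clusters, n_iters=10, seed=0):
    """Cluster L2-normalized row vectors by cosine similarity."""
    rng = np.random.default_rng(seed)
    # Initialize the centroids with randomly chosen samples.
    centroid_idx = rng.permutation(len(embeddings))[:n_clusters]
    centroids = embeddings[centroid_idx].copy()

    for n_iter in range(n_iters + 1):
        # E step: assign each sample to the centroid with the largest
        # dot product (= cosine similarity for unit vectors).
        similarities = embeddings @ centroids.T
        assignments = similarities.argmax(axis=1)
        if n_iter == n_iters:
            break
        # M step: average the members of each non-empty cluster and
        # project the mean back onto the unit sphere.
        for k in range(n_clusters):
            members = embeddings[assignments == k]
            if len(members) > 0:
                mean = members.mean(axis=0)
                centroids[k] = mean / np.linalg.norm(mean)
    return assignments, centroids


# Two hypothetical groups of directions in 3-D.
rng = np.random.default_rng(42)
group_a = np.array([1.0, 0.0, 0.0]) + 0.05 * rng.normal(size=(20, 3))
group_b = np.array([0.0, 1.0, 0.0]) + 0.05 * rng.normal(size=(20, 3))
points = np.vstack([group_a, group_b])
points /= np.linalg.norm(points, axis=1, keepdims=True)

labels, centers = spherical_kmeans(points, n_clusters=2)
print(np.bincount(labels, minlength=2))
```

As in the implementation below, the loop ends with an E step, so the returned assignments always match the final centroids.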

In [None]:
def init_embeddings(
    dataloader: torch.utils.data.DataLoader,
    model: nn.Module,
    device: torch.device,
    feature_dimensions: int,
    crops_for_assign: tuple[int, ...],
    drop_last: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    size_dataset = len(dataloader.dataset)
    logger.info("Dataset size is %i samples", size_dataset)

    if drop_last:
        size_dataset -= size_dataset % dataloader.batch_size
        logger.warning(
            "Adjusted size of memory per process due to drop_last=True to %i",
            size_dataset,
        )

    logger.info("Processing %i samples in total", size_dataset)

    indexes = torch.zeros(size_dataset).long().to(device=device)
    embeddings = torch.zeros(
        len(crops_for_assign),
        size_dataset,
        feature_dimensions,
    ).to(device=device)

    start_idx = 0
    with torch.no_grad():
        logger.info("Start initializing the embeddings")
        for index, inputs in dataloader:
            logger.info(
                "Processing %i samples to initialize embeddings", index.size(0)
            )
            n_indexes = inputs[0].size(0)
            index = index.to(device=device, non_blocking=True)

            # get embeddings
            outputs = []
            for crop_idx in crops_for_assign:
                inp = inputs[crop_idx].to(device=device, non_blocking=True)
                outputs.append(model(inp)[0])

            # fill the memory bank
            indexes[start_idx : start_idx + n_indexes] = index
            for mb_idx, embedding in enumerate(outputs):
                embeddings[mb_idx][
                    start_idx : start_idx + n_indexes
                ] = embedding
            start_idx += n_indexes
    logger.info(
        "Initialization of embeddings done with %s indexes",
        indexes.size(),
    )
    return indexes, embeddings


def cluster_embeddings(
    epoch: int,
    model,
    indexes: torch.Tensor,
    embeddings: torch.Tensor,
    size_dataset: int,
    device: torch.device,
    crops_for_assign: tuple[int, ...],
    nmb_prototypes: tuple[int, ...],
    feature_dimensions: int,
    n_epochs: int,
    nmb_kmeans_iters: int,
    plots_path: pathlib.Path = pathlib.Path("."),
) -> tuple[torch.Tensor, torch.Tensor]:
    logger.info("Clustering %i samples", size_dataset)

    # j defines which crops are used for the K-means run.
    # E.g. if the number of crops (``len(crops_for_assign)``) is 2, and
    # ``nmb_prototypes = [30, 30, 30, 30]``, the crops will
    # be used as follows:
    #
    # 1. K=30, j=0
    # 2. K=30, j=1
    # 3. K=30, j=0
    # 4. K=30, j=1
    j = 0

    n_heads = len(nmb_prototypes)
    n_clusters = nmb_prototypes[0]

    assignments_per_prototype = (torch.zeros(n_heads, size_dataset).long()).to(
        device
    )
    indexes_per_prototype = torch.zeros(n_heads, size_dataset).long().to(device)
    # Use an integer dtype: the centroid indexes are later used for indexing.
    centroids_indexes_per_prototype = (
        torch.zeros(n_heads, n_clusters).long().to(device)
    )

    embeddings_per_prototype = torch.zeros(
        n_heads,
        *tuple(embeddings.size()),
    ).to(device)
    distances_per_prototype = torch.zeros(n_heads, size_dataset).to(device)

    with torch.no_grad():
        for i_K, K in enumerate(nmb_prototypes):
            # run k-means

            # init with random samples as centroids from the dataset
            centroids = torch.empty(K, feature_dimensions).to(
                device=device, non_blocking=True
            )

            batch_size = len(embeddings[j])
            random_idx = torch.randperm(batch_size)[:K]
            assert len(random_idx) >= K, (
                f"Please reduce the number of centroids K={K}: "
                f"K must be smaller than batch size {batch_size}"
            )
            centroids = embeddings[j][random_idx]

            for n_iter in range(nmb_kmeans_iters + 1):
                # E step
                dot_products = torch.mm(embeddings[j], centroids.t())
                distances, assignments = dot_products.max(dim=1)

                # finish
                if n_iter == nmb_kmeans_iters:
                    break

                # M step
                where_helper = _get_indices_sparse(assignments.cpu().numpy())
                counts = (
                    torch.zeros(K).to(device=device, non_blocking=True).int()
                )
                emb_sums = torch.zeros(K, feature_dimensions).to(
                    device=device, non_blocking=True
                )
                for k in range(len(where_helper)):
                    if len(where_helper[k][0]) > 0:
                        emb_sums[k] = torch.sum(
                            embeddings[j][where_helper[k][0]],
                            dim=0,
                        )
                        counts[k] = len(where_helper[k][0])
                mask = counts > 0
                centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1)

                # normalize centroids
                centroids = nn.functional.normalize(centroids, dim=1, p=2)

            # Copy centroids to model for forwarding
            getattr(
                model.prototypes,
                "prototypes" + str(i_K),
            ).weight.copy_(centroids)

            logger.info("embeddings[j=%i]: %s", j, embeddings[j].size())

            # Save results to local tensors
            assignments_per_prototype[i_K][indexes] = assignments
            indexes_per_prototype[i_K][indexes] = indexes
            centroids_indexes_per_prototype[i_K] = random_idx
            # For the embeddings, make sure to use j for indexing
            embeddings_per_prototype[i_K][j][indexes] = embeddings[j]

            # next memory bank to use
            j = (j + 1) % len(crops_for_assign)

    return assignments_per_prototype, centroids_indexes_per_prototype


def _get_indices_sparse(data):
    cols = np.arange(data.size)
    M = csr_matrix(
        (cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size)
    )
    return [np.unravel_index(row.data, data.shape) for row in M]

## The Training Step

The `train` method trains our model (i.e. runs the samples through the CNN+MLP and performs the $K$-Means) and then computes the loss to eventually adjust the model weights via backpropagation.

The loss here is the cross-entropy between the model's cluster scores for each random crop and the $K$-Means cluster assignment of the sample the crop was taken from.

For example: the same sample (day in our time series) results in two random crops. If the ResNet+MLP then scores one of the crops highest for a cluster other than the one assigned to the sample, the model gets "punished" because this increases the loss.
On the other hand, if both crops are scored highest for the assigned cluster - which is desirable - the loss decreases.
Eventually, this will push our model to bring similar samples closer to each other: their positions in feature space will evolve to eventually have clusters of similar samples.

We hope that these clusters reflect similar LSWRs.
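
The effect of consistent versus inconsistent crops on the loss can be reproduced with a small NumPy sketch. It mimics the combination of temperature scaling and cross-entropy used in the training step. The scores and labels are made-up numbers, and the helper `cross_entropy` is a simplified stand-in for `nn.CrossEntropyLoss`.

```python
import numpy as np


def cross_entropy(scores: np.ndarray, targets: np.ndarray) -> float:
    """Mean cross-entropy of row-wise softmax(scores) vs. integer targets."""
    shifted = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())


temperature = 0.1
# Cluster label assigned by K-means to each of 2 samples (hypothetical).
assignments = np.array([0, 2])
# Both crops of a sample share its label; rows are ordered crop by crop.
targets = np.tile(assignments, 2)

# Hypothetical scores of 2 crops x 2 samples for 3 clusters.
consistent = np.array(
    [[0.9, 0.1, 0.0], [0.0, 0.1, 0.9], [0.8, 0.2, 0.1], [0.1, 0.0, 0.9]]
)
# Same scores, but the second crop of sample 0 now favours cluster 1.
inconsistent = consistent.copy()
inconsistent[2] = [0.2, 0.8, 0.1]

loss_consistent = cross_entropy(consistent / temperature, targets)
loss_inconsistent = cross_entropy(inconsistent / temperature, targets)
print(loss_consistent < loss_inconsistent)  # True
```

Dividing by a small temperature sharpens the softmax, so a crop that favours the wrong cluster is penalized more strongly.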

In [None]:
def train(
    dataloader: torch.utils.data.DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    indexes: torch.Tensor,
    embeddings: torch.Tensor,
    nmb_crops: tuple[int, ...],
    crops_for_assign: tuple[int, ...],
    nmb_prototypes: tuple[int, ...],
    feature_dimensions: int,
    n_epochs: int,
    nmb_kmeans_iters: int,
    temperature: float,
    device: torch.device,
) -> tuple[
    tuple[int, float], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
    batch_time = averaging.AverageMeter()
    data_time = averaging.AverageMeter()
    losses = averaging.AverageMeter()

    model.train()

    cross_entropy = nn.CrossEntropyLoss()

    assignments, centroids_indexes = cluster_embeddings(
        epoch=epoch,
        model=model,
        indexes=indexes,
        embeddings=embeddings,
        size_dataset=len(dataloader.dataset),
        device=device,
        crops_for_assign=crops_for_assign,
        nmb_prototypes=nmb_prototypes,
        feature_dimensions=feature_dimensions,
        nmb_kmeans_iters=nmb_kmeans_iters,
        n_epochs=n_epochs,
    )

    logger.info("Clustering for epoch %i done", epoch)

    end = time.time()
    start_idx = 0
    for it, (idx, inputs) in enumerate(dataloader):
        logger.debug("Calculating loss for index %s", idx)

        # measure data loading time
        data_time.update(time.time() - end)

        # ============ multi-res forward passes ... ============
        # Output here returns the output for each head (prototype)
        # and hence has size ``len(nmb_prototypes)``.
        emb, output = model(inputs)
        emb = emb.detach()
        bs = inputs[0].size(0)

        if bs == 0:
            raise RuntimeError(
                f"Batch size is zero, loss will be NaN: it={it}, idx={idx}, "
                f"inputs[0]={inputs[0]}"
            )

        # ============ deepcluster-v2 loss ... ============
        loss = torch.tensor(0.0).to(device=device)
        for h in range(len(nmb_prototypes)):
            scores = output[h] / temperature
            targets = (
                assignments[h][idx]
                .repeat(sum(nmb_crops))
                .to(device=device, non_blocking=True)
            )
            loss += cross_entropy(scores, targets)
        loss /= len(nmb_prototypes)

        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ============ update memory banks ... ============
        indexes[start_idx : start_idx + bs] = idx
        for i, crop_idx in enumerate(crops_for_assign):
            embeddings[i][start_idx : start_idx + bs] = emb[  # noqa: E203
                crop_idx * bs : (crop_idx + 1) * bs
            ]
        start_idx += bs

        # ============ misc ... ============
        losses.update(loss.item(), bs)
        batch_time.update(time.time() - end)
        end = time.time()

        if it % 50 == 0:
            logger.info(
                "[EPOCH %i, ITERATION %i] "
                "batch time: %s (%s) "
                "data load time: %s (%s) "
                "loss: %s (%s) "
                "lr: %s",
                epoch,
                it,
                batch_time.val,
                batch_time.avg,
                data_time.val,
                data_time.avg,
                losses.val,
                losses.avg,
                optimizer.state_dict()["param_groups"][0]["lr"],
            )

    logger.debug(
        "========= Memory Summary at epoch %i =======\n%s\n",
        epoch,
        torch.cuda.memory_summary(),
    )

    return (
        (epoch, losses.avg),
        indexes,
        embeddings,
        assignments,
        centroids_indexes,
    )

## The Training Loop

The `start` method is the "entry point" of our training: here we initialize the data loader, model and optimizer (Stochastic Gradient Descent).

Then, we perform an initial forwarding of our samples through the model via `init_embeddings` and then perform the training step (`train` method)
for the desired number of epochs.

Meanwhile, we track our training progress, most importantly the loss.

In [None]:
def start(
    train_dataset: a6.datasets.crop.Base,
    epochs: int,
    nmb_crops: tuple[int, ...],
    crops_for_assign: tuple[int, ...],
    model_architecture: a6.models.resnet.Architecture = a6.models.resnet.Architecture.ResNet50,
    batch_size: int = 64,
    drop_last: bool = False,
    hidden_mlp: int = 2048,
    feature_dimensions: int = 128,
    nmb_kmeans_iters: int = 10,
    nmb_prototypes: int = 3,
    nmb_clusters: int = 40,
    base_lr: float = 4.8,
    weight_decay: float = 1e-6,
    temperature: float = 0.1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    nmb_prototypes = [nmb_clusters for _ in range(nmb_prototypes)]
    logger.info("Prototypes for model: %s", nmb_prototypes)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=True,
        # ``drop_last=True`` gives each device the same amount of samples,
        # but removes some from the clustering.
        drop_last=drop_last,
        worker_init_fn=a6.utils.distributed.set_dataloader_seeds,
    )
    logger.info("Building data done with %s images loaded", len(train_dataset))

    device = torch.device("cuda:0")

    logger.info("Using device %s", device)

    # build model
    model = a6.models.resnet.Models[model_architecture](
        normalize=True,
        in_channels=train_dataset.n_channels,
        hidden_mlp=hidden_mlp,
        output_dim=feature_dimensions,
        nmb_prototypes=nmb_prototypes,
        device=device,
    )
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Copy model to GPU
    model = model.to(device)

    logger.info(model)
    logger.info("Building model done")

    # build optimizer
    # Should be done after moving the model to GPU
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=base_lr,
        momentum=0.9,
        weight_decay=weight_decay,
    )

    logger.info("Building optimizer done")
    training_stats = logs.Stats("stats.csv", columns=("epoch", "loss"))

    indexes, embeddings = init_embeddings(
        dataloader=train_loader,
        model=model,
        device=device,
        feature_dimensions=feature_dimensions,
        crops_for_assign=crops_for_assign,
        drop_last=drop_last,
    )

    for epoch in range(epochs):
        # train the network for one epoch
        logger.info("============ Starting epoch %i ============", epoch)

        # set sampler
        # train_loader.sampler.set_epoch(epoch)

        # train the network
        scores, indexes, embeddings, assignments, centroids_indexes = train(
            dataloader=train_loader,
            model=model,
            optimizer=optimizer,
            epoch=epoch,
            n_epochs=epochs,
            indexes=indexes,
            embeddings=embeddings,
            nmb_crops=nmb_crops,
            crops_for_assign=crops_for_assign,
            nmb_prototypes=nmb_prototypes,
            feature_dimensions=feature_dimensions,
            nmb_kmeans_iters=nmb_kmeans_iters,
            device=device,
            temperature=temperature,
        )
        training_stats.update(scores)
    return indexes, embeddings, assignments, centroids_indexes

### Load the data set and define the data augmentation strategy

In [None]:
path = pathlib.Path(
    "/p/project/training2330/a6/data/ecmwf_era5/era5_pl_2017_2020_12.nc"
)
nmb_crops = (2,)
crops_for_assign = (0, 1)
size_crops = (0.75,)
min_scale_crops = (0.15,)
max_scale_crops = (1.0,)

train_dataset = create_dataset(
    path=path,
    is_netcdf=True,
    nmb_crops=nmb_crops,
    size_crops=size_crops,
    min_scale_crops=min_scale_crops,
    max_scale_crops=max_scale_crops,
)

### Start the training

In [None]:
(
    indexes,
    embeddings_per_crop,
    assignments_per_prototype,
    centroids_indexes_per_prototype,
) = start(
    train_dataset=train_dataset,
    epochs=10,
    nmb_crops=nmb_crops,
    crops_for_assign=crops_for_assign,
    batch_size=64,
    nmb_prototypes=3,
    nmb_clusters=40,
)

In [None]:
def plot_embeddings_using_tsne(
    embeddings: torch.Tensor,
    assignments: torch.Tensor,
    centroids: torch.Tensor,
) -> None:
    """Plot the embeddings of DCv2 using t-SNE.

    Args:
        embeddings (torch.Tensor, shape(n_samples, n_embedding_dims)):
            Embeddings of a single crop as produced by the ResNet.
        assignments (torch.Tensor, shape(n_samples)):
            Assignments by DCv2 for each sample.

            Used for coloring each sample in the plot.
        centroids (torch.Tensor): The indexes of the centroids.

    """
    logger.info("Creating plot for embeddings")

    _, ax = plt.subplots()

    ax.set_title("Embeddings for crops")

    (x, y), (x_centroids, y_centroids) = _fit_tsne(
        embeddings=embeddings, centroids=centroids
    )
    colors = _colors.create_colors_for_assigments(assignments)

    ax.scatter(x, y, c=colors, s=1)
    ax.scatter(x_centroids, y_centroids, c="red", s=20, marker="x")


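def _fit_tsne(
    embeddings: torch.Tensor, centroids: torch.Tensor
) -> tuple[zip, zip]:
    # openTSNE works on NumPy arrays, so convert the tensor first.
    result = openTSNE.TSNE().fit(embeddings.cpu().numpy())
    # The centroid indexes must be integers to be usable for indexing.
    centroid_idx = centroids.long().cpu().numpy()
    return zip(*result), zip(*result[centroid_idx])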


assignments_cpu = assignments_per_prototype[-1].cpu()
embeddings_cpu = embeddings_per_crop[-1].cpu()
centroids_indexes_cpu = centroids_indexes_per_prototype[-1].cpu()
plot_embeddings_using_tsne(
    embeddings=embeddings_cpu,
    assignments=assignments_cpu,
    centroids=centroids_indexes_cpu,
)
a6.plotting.assignments.plot_abundance(assignments_cpu)
a6.plotting.transitions.plot_transition_matrix_heatmap(assignments_cpu)
a6.plotting.transitions.plot_transition_matrix_clustermap(assignments_cpu)

## Possible Tasks

1. Abundance of LSWRs vs abundance of DWD Großwetterlagen
   1. total
   2. per month
2. Plot meteorological quantities for individual clusters and compare (`a6.plotting.plot_fields_for_dates`).
3. Plot meteorological quantities for samples closest to the cluster centroids and compare (`a6.plotting.plot_fields_for_dates`).
4. Vary the number of clusters (`nmb_clusters`) and compare the results.
5. Adapt the data augmentation strategy and investigate its effect on the results.
6. Relation of forecast error to LSWRs (`a6.studies.grid_search.perform_forecast_model_grid_search`).
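
As a possible starting point for the abundance and transition tasks, the sketch below counts how often each cluster occurs and how often one cluster is followed by another on the next day. The label sequence is made up; in practice it would come from the cluster assignments of one prototype, converted to a Python list.

```python
from collections import Counter

# Hypothetical daily cluster labels for ten consecutive days.
labels = [0, 0, 1, 1, 1, 2, 0, 0, 2, 2]

# Abundance: number of days assigned to each cluster.
abundance = Counter(labels)

# Transitions: how often cluster a on one day is followed by cluster b.
transitions = Counter(zip(labels[:-1], labels[1:]))

print(sorted(abundance.items()))  # [(0, 4), (1, 3), (2, 3)]
print(transitions[(0, 0)], transitions[(1, 2)])  # 2 1
```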