# Analysis of DeepClusterV2 results

This notebook analyzes the results of the DeepClusterV2 (DCv2) algorithm.

DCv2 produces a sequence of labels, one cluster assignment per sample.
Since the input is a time series of weather data, these assignments are interpreted as large-scale weather regimes (LSWRs).

The goal is to analyze the following statistics of each LSWR:

* Abundance, i.e. how often it occurs
* Duration of occurrence

#### Main results

![Cluster embeddings projected into 2-D space using t-SNE](./plots/dcv2-30-clusters-2d.png)

![Cluster time series](./plots/dcv2-30-clusters-time-series.png)

![Cluster durations](./plots/dcv2-30-clusters-duration.png)

* The mean duration of the LSWRs ranges from 1 to 2 days.
* Some LSWRs show a large variance in duration.
* Most LSWRs have a standard deviation of about 12 hours.
* LSWRs 3 and 25 never last longer than 1 day.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pathlib
from typing import Iterator

import datetime
import re

import torch
import yaml
import matplotlib as mpl
import openTSNE
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt

import a6

COLORS = {
    -1: "#000000",
    0: "#FFFF00",
    1: "#1CE6FF",
    2: "#FF34FF",
    3: "#FF4A46",
    4: "#008941",
    5: "#006FA6",
    6: "#A30059",
    7: "#FFDBE5",
    8: "#7A4900",
    9: "#0000A6",
    10: "#63FFAC",
    11: "#B79762",
    12: "#004D43",
    13: "#8FB0FF",
    14: "#997D87",
    15: "#5A0007",
    16: "#809693",
    17: "#FEFFE6",
    18: "#1B4400",
    19: "#4FC601",
    20: "#3B5DFF",
    21: "#4A3B53",
    22: "#FF2F80",
    23: "#61615A",
    24: "#BA0900",
    25: "#6B7900",
    26: "#00C2A0",
    27: "#FFAA92",
    28: "#FF90C9",
    29: "#B903AA",
    30: "#D16100",
    31: "#DDEFFF",
    32: "#000035",
    33: "#7B4F4B",
    34: "#A1C299",
    35: "#300018",
    36: "#0AA6D8",
    37: "#013349",
    38: "#00846F",
}

## Load data for a specific epoch

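Loads the following outputs of the given epoch:

* Centroid positions in embedding space
* Indexes of the samples that were used as centroids in that epoch
* Cluster assignments
* Sample embeddings
* Indexes of the data samples in the dataset
* Distances of the samples to the cluster centroids in embedding space

It also loads `image-samples.yaml`, which maps sample indexes to strings containing each sample's date.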

In [None]:
path = pathlib.Path("/home/fabian/data/8080567")
epoch = 1199


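def _load_tensor_file(name: str, epoch: int | None = None) -> torch.Tensor:
    """Load a tensor saved during training, optionally for a given epoch."""
    prefix = "" if epoch is None else f"epoch-{epoch}-"
    # Map to CPU so that tensors saved on a GPU can be loaded without one.
    return torch.load(
        path / f"tensors/{prefix}{name}", map_location=torch.device("cpu")
    )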


centroids = _load_tensor_file("centroids.pt", epoch=epoch)
centroid_indexes = _load_tensor_file("centroid-indexes.pt", epoch=epoch)
assignments = _load_tensor_file("assignments.pt", epoch=epoch)
embeddings = _load_tensor_file("embeddings.pt", epoch=epoch)
indexes = _load_tensor_file("indexes.pt", epoch=epoch)
distances = _load_tensor_file("distances.pt", epoch=epoch)

with open(path / "image-samples.yaml") as f:
    sample_indexes = yaml.safe_load(f)

## Order centroid indexes

Sort and deduplicate the centroid indexes so that they can be used for indexing.

In [None]:
# `unique()` returns the sorted unique values (`sorted=True` by default).
centroid_indexes = centroid_indexes.unique()

## How many samples are unassigned?

When training on multiple GPUs, some samples may not be processed and therefore remain unassigned (label `-1`).

> _Note:_ This is controlled by `config.DATA.TRAIN.DROP_LAST`. If `True`, each GPU processes `floor(n_samples / n_gpus)` samples.
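Under the assumption stated in the note (each GPU processes `floor(n_samples / n_gpus)` samples and nothing else is dropped), the expected number of unassigned samples is the remainder `n_samples % n_gpus`. The helper below is a hypothetical sketch of that arithmetic, not part of DCv2 or `a6`:

```python
def expected_unassigned_samples(n_samples: int, n_gpus: int) -> int:
    """Estimate how many samples remain unassigned when DROP_LAST is True.

    Assumes each GPU processes floor(n_samples / n_gpus) samples and that
    no further samples are dropped (incomplete batches are ignored here).
    """
    if n_gpus < 1:
        raise ValueError("n_gpus must be at least 1")
    samples_per_gpu = n_samples // n_gpus
    return n_samples - samples_per_gpu * n_gpus


# 1000 samples on 3 GPUs: each GPU gets 333 samples, so 1 is left over.
print(expected_unassigned_samples(1000, 3))  # 1
```

If the count obtained below differs from this estimate, samples are likely being dropped elsewhere as well, e.g. in incomplete batches.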

In [None]:
def get_number_of_unassigned_samples(labels: torch.Tensor) -> int:
    """Count the samples labeled -1, i.e. never assigned to a cluster."""
    return int((labels == -1).sum())


get_number_of_unassigned_samples(assignments)

## Get the index of the crops used for the last iteration

DCv2 alternates the crops used for K-means clustering at each iteration. Hence, the embeddings saved to disk are non-zero only at the index of the crops used in the last iteration.

E.g. if `n_crops=3` is used and K-means is run 4 times, the index of the crops used for the last iteration is 0: the first iteration uses `i_crops=0`, the second `i_crops=1`, the third `i_crops=2`, and the fourth wraps around to `i_crops=0` again.
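Under the round-robin scheme described above, the crop index of the last run is `(n_runs - 1) % n_crops`. A minimal sketch, assuming the index starts at 0 and advances by one per K-means run; `n_runs` and `n_crops` are illustrative names, not DCv2 configuration keys:

```python
def crop_index_of_last_run(n_runs: int, n_crops: int) -> int:
    """Return the crop index used by the last of `n_runs` K-means runs.

    Assumes run k (0-based) uses crop index k % n_crops.
    """
    if n_runs < 1 or n_crops < 1:
        raise ValueError("n_runs and n_crops must be positive")
    return (n_runs - 1) % n_crops


# Runs use crop indexes 0, 1, 2, 0, 1, so the last run uses index 1.
print(crop_index_of_last_run(n_runs=5, n_crops=3))  # 1
```

The notebook itself does not rely on this formula; it infers the index from the saved embeddings instead, which also works when the number of runs is unknown.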

In [None]:
def get_index_of_crop(embs: torch.Tensor) -> int:
    """Return the index of the crops used for the last iteration."""
    for i, emb in enumerate(embs):
        # The embeddings are all zeros, except for the
        # index of the crops used for the last iteration.
        if torch.count_nonzero(emb) > 0:
            return i
    raise ValueError(f"All embeddings are zero (shape {tuple(embs.shape)})")


# j defines the index of the crops used for the last iteration.
j = get_index_of_crop(embeddings[-1])
j

## Use t-SNE to project the sample positions from embedding space to 2D

In [None]:
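import numpy as np


def _fit_tsne(
    embeddings: torch.Tensor, centroids_: torch.Tensor
) -> tuple[tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]]:
    """Project embeddings to 2-D with t-SNE.

    Returns the (x, y) coordinates of all samples and of the centroids.
    """
    result = np.asarray(openTSNE.TSNE().fit(embeddings.cpu().numpy()))
    centroids_2d = result[centroids_.cpu().numpy()]
    return (result[:, 0], result[:, 1]), (centroids_2d[:, 0], centroids_2d[:, 1])


(x, y), (x_centroids, y_centroids) = _fit_tsne(
    embeddings[-1][j], centroids_=centroid_indexes
)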

## Plot 2-D embeddings

In [None]:
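def _get_colors(labels: torch.Tensor | range | list[str | int]) -> list[str]:
    """Map cluster labels to their colors in `COLORS`."""
    return [COLORS[int(label)] for label in labels]


fig = plt.figure()
ax = fig.add_subplot()
colors = _get_colors(assignments[-1])
# Cluster label of each centroid sample (not the centroid positions loaded above)
centroid_labels = assignments[-1][centroid_indexes]
colors_centroids = _get_colors(centroid_labels)
ax.scatter(x, y, c=colors, s=1, label="Samples")
ax.scatter(
    x_centroids,
    y_centroids,
    facecolor=colors_centroids,
    edgecolor="black",
    linewidth=1,
    s=20,
    # "+" is an unfilled marker; the filled plus "P" draws both
    # the cluster color and the black edge.
    marker="P",
    label="Centroids",
)
ax.legend()
plt.savefig("./plots/dcv2-30-clusters-2d.png")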

## Plot the time series of the first 364 cluster assignments (52 weeks), one row per week
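The plot arranges the daily labels in rows of seven, so each row is one week. Assuming one sample per day, 364 = 52 × 7 days fill exactly 52 weeks, which is what lets the `reshape((-1, 7))` below succeed. A plain-Python sketch of the same layout (the helper name is hypothetical):

```python
def to_weekly_rows(daily_labels: list[int], days_per_week: int = 7) -> list[list[int]]:
    """Split a daily label sequence into rows of one week each.

    Assumes one label per day. Trailing days that do not fill a whole
    week are dropped.
    """
    n_full_days = len(daily_labels) // days_per_week * days_per_week
    return [
        daily_labels[start : start + days_per_week]
        for start in range(0, n_full_days, days_per_week)
    ]


rows = to_weekly_rows(list(range(364)))
print(len(rows), len(rows[0]))  # 52 7
```

Unlike `reshape`, which raises an error when the length is not a multiple of 7, this helper silently drops the incomplete trailing week.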

In [None]:
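n_days = 364  # 52 full weeks, assuming one sample per day
subset = assignments[-1][:n_days]
# One row per week, one column per day of the week
reshaped = subset.reshape((-1, 7))

# One colormap bin per label, so each label gets its color from COLORS
# (a colormap built from the per-sample colors in time order would not).
label_values = list(range(int(subset.min()), int(subset.max()) + 1))
cmap = mpl.colors.ListedColormap(_get_colors(label_values))
boundaries = [value - 0.5 for value in label_values] + [label_values[-1] + 0.5]
norm = mpl.colors.BoundaryNorm(boundaries, cmap.N)

plt.figure()
plt.pcolormesh(reshaped.numpy(), cmap=cmap, norm=norm)
plt.colorbar(ticks=label_values, label="LSWR")
plt.xlabel("Day of week")
plt.ylabel("Week")

plt.savefig("./plots/dcv2-30-clusters-time-series.png")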

## Calculate statistics of the cluster assignments (LSWRs)

For each LSWR (cluster label), find all its occurrences in the time series and calculate the following statistics:

* Total abundance
* Mean duration (with standard deviation)
* Median duration
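A minimal sketch of how such statistics can be derived from a daily label sequence: consecutive days with the same label form one occurrence, and the length of that run is its duration. This is a plain-Python illustration assuming one label per day, not the implementation of `a6.modes.methods.determine_lifetimes_of_modes`; the helper name and the returned keys are hypothetical. Abundance is reported both as the number of occurrences and as the total number of days, since it is an assumption which of the two `a6` uses.

```python
import itertools
import statistics


def lswr_duration_statistics(daily_labels: list[int]) -> dict[int, dict[str, float]]:
    """Compute abundance and duration statistics per label.

    Assumes one label per day. An occurrence is a maximal run of
    consecutive days with the same label; its duration is the run length.
    """
    durations: dict[int, list[int]] = {}
    for label, run in itertools.groupby(daily_labels):
        durations.setdefault(label, []).append(sum(1 for _ in run))

    stats = {}
    for label, runs in durations.items():
        stats[label] = {
            "n_occurrences": len(runs),
            "n_days": sum(runs),
            "mean_days": statistics.mean(runs),
            # Population standard deviation; 0.0 for a single occurrence.
            "std_days": statistics.pstdev(runs),
            "median_days": statistics.median(runs),
        }
    return stats


example_labels = [0, 0, 1, 0, 0, 0, 1, 1]
# Label 0 occurs twice, for 2 and 3 days: mean 2.5, median 2.5.
print(lswr_duration_statistics(example_labels)[0])
```

Whether the standard deviation should be the population or the sample variant is a design choice; the sketch uses `statistics.pstdev`.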

In [None]:
def get_start_and_end_date_of_samples(
    samples: dict,
) -> tuple[datetime.datetime, datetime.datetime]:
    """Return the dates of the first and the last sample.

    Each value of `samples` is expected to contain a date as `YYYY-MM-DD`.
    """
    regex = re.compile(r"([0-9]{4}-[0-9]{2}-[0-9]{2})T?")

    def to_datetime(sample: str) -> datetime.datetime:
        [date_str] = regex.findall(sample)
        return datetime.datetime.strptime(date_str, "%Y-%m-%d")

    # Assumes the samples are stored in chronological order.
    keys = list(samples.keys())
    return to_datetime(samples[keys[0]]), to_datetime(samples[keys[-1]])


start_date, end_date = get_start_and_end_date_of_samples(sample_indexes)

# Assumes one sample per day
dates = pd.date_range(start=start_date, end=end_date, freq="1D")

labels = xr.DataArray(
    data=assignments[-1].numpy(), coords={"time": dates}, dims=["time"]
)

modes = a6.modes.methods.determine_lifetimes_of_modes(labels)

## Show the calculated statistics for one LSWR (`modes[1]`)

In [None]:
modes[1].statistics

## Plot the mean duration of all LSWRs in the time series

In [None]:
colors = _get_colors(range(modes.size))
a6.plotting.modes.plot_modes_durations(
    modes, colors=colors, start_at_index_0=True
)

plt.savefig("./plots/dcv2-30-clusters-duration.png")