# Analysis of DeepClusterV2 results

This notebook analyzes the results of the DeepClusterV2 (DCv2) algorithm.

The result of DCv2 is a sequence of labels that reflect cluster assignments.
Since we make use of time-series weather data, these assignments reflect large-scale weather regimes (LSWRs).

The goal is to analyze the statistics of each LSWR, i.e. 

* Abundance, i.e. how often it occurs
* Duration of occurrence

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pathlib
from typing import Iterator

import datetime
import re

import torch
import yaml
import matplotlib as mpl
import openTSNE
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt

import a6

# Hex color per cluster label, looked up by `_get_colors` so that a label is
# drawn with the same color wherever it is plotted.
# Key -1 is the label treated as "unassigned" elsewhere in this notebook;
# keys 0-38 cover the cluster labels.
COLORS = {
    -1: "#000000",
    0: "#FFFF00",
    1: "#1CE6FF",
    2: "#FF34FF",
    3: "#FF4A46",
    4: "#008941",
    5: "#006FA6",
    6: "#A30059",
    7: "#FFDBE5",
    8: "#7A4900",
    9: "#0000A6",
    10: "#63FFAC",
    11: "#B79762",
    12: "#004D43",
    13: "#8FB0FF",
    14: "#997D87",
    15: "#5A0007",
    16: "#809693",
    17: "#FEFFE6",
    18: "#1B4400",
    19: "#4FC601",
    20: "#3B5DFF",
    21: "#4A3B53",
    22: "#FF2F80",
    23: "#61615A",
    24: "#BA0900",
    25: "#6B7900",
    26: "#00C2A0",
    27: "#FFAA92",
    28: "#FF90C9",
    29: "#B903AA",
    30: "#D16100",
    31: "#DDEFFF",
    32: "#000035",
    33: "#7B4F4B",
    34: "#A1C299",
    35: "#300018",
    36: "#0AA6D8",
    37: "#013349",
    38: "#00846F",
}

In [None]:
# Directory of the DCv2 run to analyze; the cells below read `tensors/*.pt`
# and `image-samples.yaml` from it.
# NOTE(review): absolute, machine-specific path -- adjust it (or make it
# configurable) before running the notebook elsewhere.
path = pathlib.Path("/home/fabian/data/8080567")
# Epoch whose saved tensors are loaded (used as the `epoch-{epoch}-` file prefix).
epoch = 1199


def _load_tensor_file(name, epoch: int | None = None):
    """Load a file from the run's ``tensors`` directory onto the CPU.

    Args:
        name: File name inside ``<path>/tensors`` (e.g. ``"centroids.pt"``).
        epoch: If given, the file name is prefixed with ``epoch-{epoch}-``;
            if ``None``, no prefix is used.

    Returns:
        Whatever ``torch.load`` deserializes from the file, with storages
        mapped to the CPU.
    """
    prefix = f"epoch-{epoch}-" if epoch is not None else ""
    tensor_file = path / f"tensors/{prefix}{name}"
    # NOTE: `torch.load` unpickles the file, so only load trusted files.
    return torch.load(tensor_file, map_location=torch.device("cpu"))


# Load every tensor saved for the selected epoch (all mapped to the CPU).
# NOTE(review): the exact shapes are not visible here; the cells below index
# `assignments` and `embeddings` with `[-1]` along their first axis.
centroids = _load_tensor_file("centroids.pt", epoch=epoch)
centroid_indexes = _load_tensor_file("centroid-indexes.pt", epoch=epoch)
assignments = _load_tensor_file("assignments.pt", epoch=epoch)
embeddings = _load_tensor_file("embeddings.pt", epoch=epoch)
indexes = _load_tensor_file("indexes.pt", epoch=epoch)
distances = _load_tensor_file("distances.pt", epoch=epoch)

# NOTE(review): presumably maps sample index -> sample name containing an
# ISO date (`YYYY-MM-DD`); it is used that way further below -- confirm.
with open(path / "image-samples.yaml") as f:
    sample_indexes = yaml.safe_load(f)

In [None]:
# Deduplicate the centroid indexes. With torch's defaults, `unique()` also
# returns the values sorted, which is the ordering relied on here.
centroid_indexes = centroid_indexes.unique()

In [None]:
def get_number_of_unassigend_samples(labels: torch.Tensor) -> int:
    """Count the entries of ``labels`` that equal ``-1`` (i.e. unassigned).

    Args:
        labels: Tensor of cluster labels.

    Returns:
        Number of elements whose value is ``-1``.
    """
    is_unassigned = labels == -1
    return int(is_unassigned.sum())


# Number of entries equal to -1 (unassigned) in the whole `assignments` tensor.
get_number_of_unassigend_samples(assignments)

In [None]:
def get_index_of_crop(embs: torch.Tensor) -> int:
    """Return the first index along axis 0 whose slice has a non-zero entry.

    Args:
        embs: Tensor whose first axis is searched; ``embs[i]`` is inspected
            for each ``i``.

    Returns:
        The smallest ``i`` for which ``embs[i]`` contains a non-zero value.

    Raises:
        ValueError: If every slice ``embs[i]`` is all zeros (or the first
            axis is empty).
    """
    for i in range(embs.shape[0]):
        # The embeddings are all zeros, except for the
        # index of the crops used for the last iteration.
        if torch.count_nonzero(embs[i]) > 0:
            return i
    # Exceptions do not apply logging-style `%s` arguments, so build the
    # message explicitly.
    raise ValueError(f"All embeddings are zero: {embs}")


# j defines the index of the crops used for the last iteration.
# `embeddings[-1]` takes the last entry along the first axis of `embeddings`;
# NOTE(review): presumably that is the most recent set of embeddings -- confirm.
j = get_index_of_crop(embeddings[-1])
j

In [None]:
def _fit_tsne(
    embeddings: torch.Tensor, centroids_: torch.Tensor
) -> tuple[Iterator[tuple[float, ...]], Iterator[tuple[float, ...]]]:
    """Fit t-SNE on the embeddings and return the coordinates column-wise.

    Args:
        embeddings: Embeddings to project; moved to the CPU before fitting.
        centroids_: Indexes used to select the rows of the t-SNE result that
            belong to the centroids.

    Returns:
        Two iterators: the first yields one tuple per t-SNE dimension over
        all samples, the second does the same for the rows selected by
        ``centroids_``. The caller unpacks each into two tuples (x and y).

    NOTE(review): no ``random_state`` is passed to ``openTSNE.TSNE``, so the
    projection may differ between runs -- consider fixing a seed.
    """
    result = openTSNE.TSNE().fit(embeddings.cpu())
    # `zip(*rows)` transposes the rows of the result into per-column tuples.
    return zip(*result), zip(*result[centroids_])


# Project the embeddings selected by `embeddings[-1][j]` and unpack the
# per-dimension coordinates of all samples and of the centroid samples.
(x, y), (x_centroids, y_centroids) = _fit_tsne(
    embeddings[-1][j], centroids_=centroid_indexes
)

In [None]:
def _get_colors(tensor: torch.Tensor | list[str | int]) -> list[str]:
    """Translate each label into its hex color from ``COLORS``.

    Args:
        tensor: Iterable of labels; each element is converted with ``int``
            before the lookup.

    Returns:
        One hex color string per label, in input order.

    Raises:
        KeyError: If a label has no entry in ``COLORS``.
    """
    hex_colors = []
    for label in tensor:
        hex_colors.append(COLORS[int(label)])
    return hex_colors


# Scatter plot of the t-SNE projection: every sample is colored by its
# cluster label, and the samples at `centroid_indexes` are highlighted.
fig = plt.figure()
ax = fig.add_subplot()
colors = _get_colors(assignments[-1])
# NOTE: this rebinds `centroids` (previously the tensor loaded from
# `centroids.pt`) to the labels of the samples at `centroid_indexes`.
centroids = assignments[-1][centroid_indexes]
colors_centroids = _get_colors(centroids)
# Both artists get a label; without labels `ax.legend()` has nothing to show
# and only emits a warning.
ax.scatter(x, y, c=colors, s=1, label="Samples")
ax.scatter(
    x_centroids,
    y_centroids,
    facecolor=colors_centroids,
    edgecolor="black",
    linewidth=1,
    s=20,
    marker="+",
    label="Centroids",
)
ax.set_xlabel("t-SNE dimension 1")
ax.set_ylabel("t-SNE dimension 2")
ax.set_title("t-SNE projection of the embeddings, colored by cluster label")
ax.legend()

In [None]:
# Show the first 364 labels as a grid with 7 columns (52 rows of 7 samples).
subset = assignments[-1][:364]
reshaped = subset.reshape((-1, 7))
# Build a discrete colormap with exactly one color per label value in the
# plotted range, so that label `k` is drawn with `COLORS[k]`. (A colormap
# built from the per-sample colors would be indexed by time position rather
# than by label value.)
lowest_label, highest_label = int(subset.min()), int(subset.max())
colors = _get_colors(range(lowest_label, highest_label + 1))
cmap = mpl.colors.ListedColormap(colors, name="Custom cmap")
# Center every integer label inside its own color bin.
plt.pcolormesh(
    reshaped,
    cmap=cmap,
    vmin=lowest_label - 0.5,
    vmax=highest_label + 0.5,
)
plt.colorbar()

In [None]:
def get_start_and_end_date_of_samples(
    samples: dict,
) -> tuple[datetime.datetime, datetime.datetime]:
    """Extract the dates of the first and the last sample.

    Args:
        samples: Mapping of sample index to a string that contains exactly
            one ``YYYY-MM-DD`` date. The first sample is looked up under key
            ``0``; the last sample is the last key in insertion order.

    Returns:
        ``(start, end)`` as ``datetime.datetime`` objects at midnight.

    Raises:
        KeyError: If ``samples`` has no key ``0``.
        ValueError: If a sample string does not contain exactly one date.
    """

    def to_datetime(date: str) -> datetime.datetime:
        return datetime.datetime.strptime(date, "%Y-%m-%d")

    first = samples[0]
    # Use the `samples` argument, not a notebook-level variable, so the
    # function works for any mapping passed in.
    last = samples[list(samples)[-1]]
    regex = re.compile(r"([0-9]{4}-[0-9]{2}-[0-9]{2})T?")
    # Single-element unpacking enforces exactly one date per sample string.
    [first_date_str] = regex.findall(first)
    [last_date_str] = regex.findall(last)
    return to_datetime(first_date_str), to_datetime(last_date_str)


# Derive the covered date range from the first and last sample names.
start_date, end_date = get_start_and_end_date_of_samples(sample_indexes)

# One timestamp per day between the first and last sample (inclusive).
# NOTE(review): assumes one sample per day without gaps, so that the number
# of dates matches the number of labels below -- confirm.
dates = pd.date_range(start=start_date, end=end_date, freq="1D")

# Attach the time coordinate to the labels taken from `assignments[-1]`.
labels = xr.DataArray(
    data=assignments[-1].numpy(), coords={"time": dates}, dims=["time"]
)

# Compute the lifetimes of each mode (LSWR) from the labeled time series.
modes = a6.modes.methods.determine_lifetimes_of_modes(labels)

In [None]:
# Inspect the statistics of a single mode.
# NOTE(review): index 3 looks like an arbitrary example -- confirm.
modes[3].statistics

In [None]:
# One color per mode, looked up by the mode's position 0..modes.size-1.
colors = _get_colors(range(modes.size))
# NOTE(review): `start_at_index_0` presumably makes the plotted mode labels
# start at 0 to match the cluster labels -- confirm against `a6.plotting`.
a6.plotting.modes.plot_modes_durations(
    modes, colors=colors, start_at_index_0=True
)