In [None]:
import pathlib
import sys
import time
import logging

import numpy as np
import scipy.constants as constants
import torch
import torch.nn as nn
import xarray as xr
from scipy.sparse import csr_matrix

import a6
import a6.dcv2._logs as logs
import a6.dcv2._averaging as averaging

# Route all log records to stdout so they appear in the notebook's cell output.
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)

# Logger shared by all cells of this notebook.
logger = logging.getLogger("notebook")
logger.info("Logger initialized")

In [None]:
@a6.utils.make_functional
def calculate_geopotential_height(
    data: xr.Dataset,
    scaling: float = 10.0,
) -> xr.Dataset:
    """Calculate the geopotential height from the geopotential.

    Computes ``z_h = z / g / scaling``, where ``g`` is
    ``scipy.constants.g`` (standard acceleration of gravity).

    Parameters
    ----------
    data : xr.Dataset
        Data containing the geopotential as variable ``z``.
    scaling : float, default=10.0
        Divisor applied after dividing by ``g``. Assuming ``z`` is given in
        m^2 s^-2 (as for ERA5 -- confirm for other sources), ``z / g`` is in
        meters and the default ``scaling=10.0`` yields decameters.

    Returns
    -------
    xr.Dataset
        A new dataset with the additional variable ``z_h``. The input
        ``data`` is not modified.

    """
    # ``assign`` returns a new dataset instead of mutating the caller's one.
    return data.assign(z_h=data["z"] / constants.g / scaling)


@a6.utils.make_functional
def drop_variables_from_dataset(
    data: xr.Dataset,
    names: str | list[str],
) -> xr.Dataset:
    """Return the dataset without the given variable(s).

    Parameters
    ----------
    data : xr.Dataset
        Dataset to drop the variables from.
    names : str or list[str]
        Name(s) of the variables to drop; forwarded to ``Dataset.drop_vars``.

    Returns
    -------
    xr.Dataset
        Result of ``data.drop_vars(names)``.

    """
    remaining = data.drop_vars(names)
    return remaining


def create_dataset(
    paths: pathlib.Path | list[pathlib.Path],
    is_netcdf: bool,
    nmb_crops: tuple[int, ...],
    size_crops: tuple[float, ...],
    min_scale_crops: tuple[float, ...],
    max_scale_crops: tuple[float, ...],
    drop_variables: list[str] | None = None,
) -> a6.datasets.crop.Base:
    """Create a multi-crop dataset from netCDF files or from an image folder.

    Parameters
    ----------
    paths : pathlib.Path or list[pathlib.Path]
        Path(s) to the data. For netCDF data all files are opened together
        and concatenated along ``time``. For an image folder exactly one
        path (the folder) must be given.
    is_netcdf : bool
        ``True`` to read netCDF files, ``False`` for an image folder.
    nmb_crops, size_crops, min_scale_crops, max_scale_crops : tuple
        Multi-crop settings forwarded unchanged to the dataset class.
    drop_variables : list[str], optional
        Variables to skip when opening the netCDF files.

    Returns
    -------
    a6.datasets.crop.Base
        The dataset, created with ``return_index=True``.

    Raises
    ------
    ValueError
        If no path is given, or if more than one path is given for an
        image folder dataset.
    """
    # Accept a single path as well as a sequence of paths.
    path_list = (
        [paths] if isinstance(paths, (str, pathlib.PurePath)) else list(paths)
    )
    if not path_list:
        raise ValueError("At least one data path must be given")

    if is_netcdf:
        logger.info("Assuming dataset from netCDF files")

        preprocessing = (
            a6.features.methods.weighting.weight_by_latitudes(
                latitudes="latitude",
                use_sqrt=True,
            )
            >> calculate_geopotential_height()
            >> drop_variables_from_dataset(names=["z"])
        )
        logger.info("Reading data from netCDF files %s", path_list)
        ds = xr.open_mfdataset(
            path_list,
            engine="netcdf4",
            concat_dim="time",
            combine="nested",
            coords="minimal",
            data_vars="minimal",
            preprocess=preprocessing,
            drop_variables=drop_variables or [],
            compat="override",
            parallel=False,
        )
        return a6.datasets.crop.MultiCropXarrayDataset(
            # NOTE(review): the data is already opened via ``dataset``; the
            # first file is passed as the representative path -- confirm how
            # ``data_path`` is used by MultiCropXarrayDataset.
            data_path=pathlib.Path(path_list[0]),
            dataset=ds,
            nmb_crops=nmb_crops,
            size_crops=size_crops,
            min_scale_crops=min_scale_crops,
            max_scale_crops=max_scale_crops,
            return_index=True,
        )

    logger.info("Assuming image folder dataset")
    if len(path_list) != 1:
        raise ValueError(
            "An image folder dataset expects exactly one path, "
            f"got {len(path_list)}: {path_list}"
        )
    return a6.datasets.crop.MultiCropDataset(
        data_path=pathlib.Path(path_list[0]),
        nmb_crops=nmb_crops,
        size_crops=size_crops,
        min_scale_crops=min_scale_crops,
        max_scale_crops=max_scale_crops,
        return_index=True,
    )


# ERA5 pressure-level netCDF file used for training.
path = pathlib.Path(
    "/p/project/training2330/a6/data/ecmwf_era5/nc/era5_pl_2012_2023_12.nc"
)

# Multi-crop settings, one entry per crop resolution (forwarded to the
# dataset class; their exact semantics are defined there).
nmb_crops, size_crops, min_scale_crops, max_scale_crops = (
    (2,),
    (0.75,),
    (0.15,),
    (1.0,),
)

# Indexes of the crops whose embeddings are kept in the memory banks.
crops_for_assign = (0, 1)

train_dataset = create_dataset(
    paths=[path],
    is_netcdf=True,
    max_scale_crops=max_scale_crops,
    min_scale_crops=min_scale_crops,
    size_crops=size_crops,
    nmb_crops=nmb_crops,
)

In [None]:
def init_memory(
    dataloader: torch.utils.data.DataLoader,
    model: nn.Module,
    device: torch.device,
    feature_dimensions: int,
    crops_for_assign: tuple[int, ...],
    drop_last: bool = False,
):
    """Fill the memory banks with an initial embedding of every sample.

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        Loader yielding ``(index, inputs)`` pairs, where ``inputs`` is a
        sequence of per-crop batches.
    model : nn.Module
        Model whose first output is used as the embedding.
    device : torch.device
        Device on which the memory banks are allocated.
    feature_dimensions : int
        Length of each embedding vector.
    crops_for_assign : tuple[int, ...]
        Indexes of the crops whose embeddings are stored; one memory bank
        is created per entry.
    drop_last : bool, default=False
        Whether ``dataloader`` drops its last incomplete batch. If so, the
        memory is shrunk to a multiple of ``dataloader.batch_size``.

    Returns
    -------
    indexes : torch.Tensor
        Long tensor of shape ``(size_memory,)`` with the dataset index held
        by each memory slot.
    embeddings : torch.Tensor
        Tensor of shape ``(len(crops_for_assign), size_memory,
        feature_dimensions)`` with the stored embeddings.

    Raises
    ------
    ValueError
        If ``drop_last`` is set but the loader has no fixed batch size.
    """
    size_dataset = len(dataloader.dataset)
    logger.info("Dataset size is %i samples", size_dataset)

    if drop_last:
        # The final incomplete batch is skipped by the loader, so only full
        # batches end up in the memory.
        batch_size = dataloader.batch_size
        if batch_size is None:
            raise ValueError(
                "drop_last=True requires a DataLoader with a fixed batch_size"
            )
        size_dataset -= size_dataset % batch_size
        logger.warning(
            "Adjusted size of memory per process due to drop_last=True to %i",
            size_dataset,
        )

    logger.info("Processing %i samples", size_dataset)

    indexes = torch.zeros(size_dataset, dtype=torch.long, device=device)
    embeddings = torch.zeros(
        len(crops_for_assign),
        size_dataset,
        feature_dimensions,
        device=device,
    )

    start_idx = 0
    with torch.no_grad():
        logger.info("Start initializing the memory banks")
        for index, inputs in dataloader:
            # Per-batch details go to DEBUG to avoid flooding the cell output.
            logger.debug(
                "Processing %i samples from data indexes %s",
                index.size(0),
                index,
            )
            n_indexes = inputs[0].size(0)
            stop_idx = start_idx + n_indexes
            index = index.to(device=device, non_blocking=True)

            # Embed every crop that is kept in a memory bank.
            outputs = [
                model(inputs[crop_idx].to(device=device, non_blocking=True))[0]
                for crop_idx in crops_for_assign
            ]

            # Fill the memory banks.
            indexes[start_idx:stop_idx] = index
            for mb_idx, embedding in enumerate(outputs):
                embeddings[mb_idx][start_idx:stop_idx] = embedding
            start_idx = stop_idx

    if start_idx != size_dataset:
        logger.warning(
            "Only %i of %i memory slots were filled; check that drop_last "
            "matches the DataLoader configuration",
            start_idx,
            size_dataset,
        )

    logger.info(
        "Initialization of the memory banks done with %s local memory indexes",
        indexes.size(),
    )
    return indexes, embeddings


def cluster_memory(
    epoch: int,
    model,
    indexes: torch.Tensor,
    embeddings: torch.Tensor,
    size_dataset: int,
    device: torch.device,
    crops_for_assign: tuple[int, ...],
    nmb_prototypes: tuple[int, ...],
    feature_dimensions: int,
    n_epochs: int,
    nmb_kmeans_iters: int,
    plots_path: pathlib.Path = pathlib.Path("."),
):
    """Cluster the memory banks with k-means, once per prototype head.

    For head ``i_K`` with ``K`` clusters, k-means is run on one memory bank
    (the banks are used in turn). Samples are assigned to the centroid with
    the largest dot product, and centroids are L2-normalized after each
    update. The final centroids are copied into
    ``model.prototypes.prototypes{i_K}``. At selected epochs, diagnostic
    plots for the last head are written to ``plots_path``.

    Parameters
    ----------
    epoch : int
        Current (0-based) epoch, used for plot names and plot scheduling.
    model
        Model exposing ``model.prototypes.prototypes{i}`` layers with a
        ``weight`` of shape ``(K, feature_dimensions)``.
    indexes : torch.Tensor
        Dataset index held by each memory slot.
    embeddings : torch.Tensor
        Memory banks of shape ``(n_banks, size_memory, feature_dimensions)``.
    size_dataset : int
        Number of samples in the dataset (size of the returned assignments).
    device : torch.device
        Device on which the result tensors are allocated.
    crops_for_assign : tuple[int, ...]
        Crops stored in the memory banks; its length is the number of banks.
    nmb_prototypes : tuple[int, ...]
        Number of clusters ``K`` for each head.
    feature_dimensions : int
        Length of each embedding vector.
    n_epochs : int
        Total number of epochs (the last epoch is always plotted).
    nmb_kmeans_iters : int
        Number of k-means update iterations per head.
    plots_path : pathlib.Path, default="."
        Output directory for the plots.

    Returns
    -------
    torch.Tensor
        Long tensor of shape ``(len(nmb_prototypes), size_dataset)`` with the
        cluster assignment of each dataset sample for every head. Samples not
        present in the memory keep the assignment 0.

    Raises
    ------
    ValueError
        If a head requests more clusters than the memory bank has samples.
    """
    logger.info("Clustering %i samples", size_dataset)

    # ``j`` selects the memory bank (i.e. crop) used for each k-means run
    # and cycles through the banks. E.g. with 2 banks and
    # ``nmb_prototypes = [30, 30, 30, 30]`` the runs are:
    #
    # 1. K=30, j=0
    # 2. K=30, j=1
    # 3. K=30, j=0
    # 4. K=30, j=1
    j = 0
    j_prev = 0
    random_idx = None

    n_heads = len(nmb_prototypes)

    assignments_per_prototype = torch.zeros(
        n_heads, size_dataset, dtype=torch.long, device=device
    )
    # Only bank ``j`` of each head is filled; used for the t-SNE plot below.
    embeddings_per_prototype = torch.zeros(
        n_heads,
        *tuple(embeddings.size()),
        device=device,
    )

    with torch.no_grad():
        for i_K, K in enumerate(nmb_prototypes):
            bank = embeddings[j]
            size_memory = len(bank)
            # Explicit check instead of ``assert``, which is stripped under -O.
            if K > size_memory:
                raise ValueError(
                    f"Please reduce the number of centroids K={K}: "
                    f"K must not exceed the number of samples in the memory "
                    f"bank ({size_memory})"
                )

            # Init with K random samples of the memory bank as centroids.
            # Advanced indexing copies, so updating ``centroids`` below does
            # not modify the memory bank.
            random_idx = torch.randperm(size_memory)[:K]
            centroids = bank[random_idx]

            for n_iter in range(nmb_kmeans_iters + 1):
                # E step: assign each sample to the most similar centroid.
                dot_products = torch.mm(bank, centroids.t())
                _, assignments = dot_products.max(dim=1)

                # The final iteration only computes the assignments.
                if n_iter == nmb_kmeans_iters:
                    break

                # M step: mean of the members of each non-empty cluster.
                where_helper = _get_indices_sparse(assignments.cpu().numpy())
                counts = torch.zeros(K, dtype=torch.int, device=device)
                emb_sums = torch.zeros(K, feature_dimensions, device=device)
                for k in range(len(where_helper)):
                    members = where_helper[k][0]
                    if len(members) > 0:
                        emb_sums[k] = torch.sum(bank[members], dim=0)
                        counts[k] = len(members)
                # Empty clusters keep their previous centroid.
                mask = counts > 0
                centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1)

                # normalize centroids
                centroids = nn.functional.normalize(centroids, dim=1, p=2)

            # Copy centroids to model for forwarding
            getattr(
                model.prototypes,
                "prototypes" + str(i_K),
            ).weight.copy_(centroids)

            logger.info("embeddings[j=%i]: %s", j, bank.size())

            # Save results, indexed by dataset index.
            assignments_per_prototype[i_K][indexes] = assignments
            # For the embeddings, make sure to use j for indexing
            embeddings_per_prototype[i_K][j][indexes] = bank

            j_prev = j
            # next memory bank to use
            j = (j + 1) % len(crops_for_assign)

        if n_heads > 0 and _should_plot_epoch(epoch + 1, n_epochs):
            assignments_cpu = assignments_per_prototype[-1].cpu()
            a6.plotting.embeddings.plot_embeddings_using_tsne(
                embeddings=embeddings_per_prototype[-1],
                # Use previous j since this represents which crops
                # were used for last cluster iteration.
                j=j_prev,
                assignments=assignments_cpu,
                # Which random samples were used as the initial centroids.
                centroids=random_idx,
                name=f"epoch-{epoch}-embeddings",
                output_dir=plots_path,
            )
            a6.plotting.assignments.plot_abundance(
                assignments=assignments_cpu,
                name=f"epoch-{epoch}-assignments-abundance",
                output_dir=plots_path,
            )
            a6.plotting.transitions.plot_transition_matrix_heatmap(
                assignments_cpu,
                name=f"epoch-{epoch}-transition-heatmap",
                output_dir=plots_path,
            )
            a6.plotting.transitions.plot_transition_matrix_clustermap(
                assignments_cpu,
                name=f"epoch-{epoch}-transition-clustermap",
                output_dir=plots_path,
            )

    return assignments_per_prototype


def _should_plot_epoch(epoch_comp: int, n_epochs: int) -> bool:
    """Decide whether to plot for the 1-based epoch ``epoch_comp``.

    Plots are made for the first epoch, every 25 epochs up to epoch 100,
    every 100 epochs, and for the last epoch.
    """
    return (
        epoch_comp == 1
        or (epoch_comp <= 100 and epoch_comp % 25 == 0)
        or epoch_comp % 100 == 0
        or epoch_comp == n_epochs
    )


def _get_indices_sparse(data):
    cols = np.arange(data.size)
    M = csr_matrix(
        (cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size)
    )
    return [np.unravel_index(row.data, data.shape) for row in M]

In [None]:
def train(
    dataloader: torch.utils.data.DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    indexes: torch.Tensor,
    embeddings: torch.Tensor,
    nmb_crops: tuple[int, ...],
    crops_for_assign: tuple[int, ...],
    nmb_prototypes: tuple[int, ...],
    feature_dimensions: int,
    n_epochs: int,
    nmb_kmeans_iters: int,
    temperature: float,
    device: torch.device,
):
    """Train ``model`` for one epoch with the DeepCluster-v2 loss.

    The memory banks are first clustered (see ``cluster_memory``) to obtain
    one pseudo-label per sample and head. The model is then trained with a
    cross-entropy loss against these labels, averaged over the heads, and
    the memory banks are refreshed with the embeddings computed on the way.

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        Loader yielding ``(idx, inputs)`` pairs, ``inputs`` being a sequence
        of per-crop batches.
    model : nn.Module
        Model returning ``(embeddings, outputs)`` with one output per head.
    optimizer : torch.optim.Optimizer
        Optimizer stepping the model parameters.
    epoch : int
        Current (0-based) epoch.
    indexes, embeddings : torch.Tensor
        Memory banks from ``init_memory`` or the previous epoch.
    nmb_crops : tuple[int, ...]
        Number of crops per resolution; their sum is the number of views
        per sample.
    crops_for_assign : tuple[int, ...]
        Crops stored in the memory banks.
    nmb_prototypes : tuple[int, ...]
        Number of clusters per head.
    feature_dimensions : int
        Length of each embedding vector.
    n_epochs : int
        Total number of epochs.
    nmb_kmeans_iters : int
        Number of k-means iterations per head.
    temperature : float
        Temperature dividing the scores before the cross-entropy.
    device : torch.device
        Device of the model and memory banks.

    Returns
    -------
    tuple
        ``((epoch, average_loss), indexes, embeddings)`` with the updated
        memory banks.

    Raises
    ------
    ValueError
        If ``nmb_prototypes`` is empty.
    RuntimeError
        If the loader yields an empty batch.
    """
    if len(nmb_prototypes) == 0:
        raise ValueError("At least one prototype head is required")

    batch_time = averaging.AverageMeter()
    data_time = averaging.AverageMeter()
    losses = averaging.AverageMeter()

    model.train()

    cross_entropy = nn.CrossEntropyLoss()
    n_heads = len(nmb_prototypes)
    n_views = sum(nmb_crops)

    assignments = cluster_memory(
        epoch=epoch,
        model=model,
        indexes=indexes,
        embeddings=embeddings,
        size_dataset=len(dataloader.dataset),
        device=device,
        crops_for_assign=crops_for_assign,
        nmb_prototypes=nmb_prototypes,
        feature_dimensions=feature_dimensions,
        nmb_kmeans_iters=nmb_kmeans_iters,
        n_epochs=n_epochs,
    )

    logger.info("Clustering for epoch %i done", epoch)

    end = time.time()
    start_idx = 0
    for it, (idx, inputs) in enumerate(dataloader):
        # Per-batch details go to DEBUG to avoid flooding the cell output.
        logger.debug("Calculating loss for index %s", idx)

        # measure data loading time
        data_time.update(time.time() - end)

        # Check before the forward pass: an empty batch makes the loss NaN.
        bs = inputs[0].size(0)
        if bs == 0:
            raise RuntimeError(
                f"Batch size is zero, loss will be NaN: it={it}, idx={idx}, "
                f"inputs[0]={inputs[0]}"
            )
        logger.debug("Batch size is %i", bs)

        # ============ multi-res forward passes ... ============
        # ``output`` holds one score tensor per head (prototype), i.e. it
        # has ``len(nmb_prototypes)`` entries.
        emb, output = model(inputs)
        emb = emb.detach()

        # ============ deepcluster-v2 loss ... ============
        # Start from a Python scalar so the sum lives on the same device as
        # the per-head losses instead of on the CPU.
        loss = 0.0
        for h in range(n_heads):
            scores = output[h] / temperature
            # Every view of a sample shares that sample's pseudo-label.
            targets = (
                assignments[h][idx]
                .repeat(n_views)
                .to(device=device, non_blocking=True)
            )
            loss_temp = cross_entropy(scores, targets)
            loss = loss + loss_temp

            if torch.isnan(loss_temp).any() or torch.isnan(loss).any():
                logger.warning(
                    (
                        "Loss is NaN: it=%i, prototype(h)=%i, "
                        "nmb_prototypes=%s, "
                        "idx=%s, assignments=%s, output=%s, targets=%s, "
                        "scores=%s, sum_nmb_crops=%s, loss_temp=%s, loss=%s, "
                        "loss.item()=%s"
                    ),
                    it,
                    h,
                    nmb_prototypes,
                    idx,
                    assignments[h][idx],
                    output[h],
                    targets,
                    scores,
                    n_views,
                    loss_temp,
                    loss,
                    loss.item(),
                )
        loss = loss / n_heads

        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ============ update memory banks ... ============
        # Embeddings are laid out crop-major: crop c occupies rows
        # ``c * bs : (c + 1) * bs``.
        indexes[start_idx : start_idx + bs] = idx
        for i, crop_idx in enumerate(crops_for_assign):
            embeddings[i][start_idx : start_idx + bs] = emb[  # noqa: E203
                crop_idx * bs : (crop_idx + 1) * bs
            ]
        start_idx += bs

        # ============ misc ... ============
        losses.update(loss.item(), bs)
        batch_time.update(time.time() - end)
        end = time.time()

        if it % 50 == 0:
            logger.info(
                "[EPOCH %i, ITERATION %i] "
                "batch time: %s (%s) "
                "data load time: %s (%s) "
                "loss: %s (%s) "
                "lr: %s",
                epoch,
                it,
                batch_time.val,
                batch_time.avg,
                data_time.val,
                data_time.avg,
                losses.val,
                losses.avg,
                optimizer.state_dict()["param_groups"][0]["lr"],
            )

    # The CUDA memory summary is only meaningful (and safe) on a CUDA device.
    if torch.device(device).type == "cuda":
        logger.info(
            "========= Memory Summary at epoch %i =======\n%s\n",
            epoch,
            torch.cuda.memory_summary(device),
        )

    return (epoch, losses.avg), indexes, embeddings

In [None]:
def start(
    train_dataset: a6.datasets.crop.Base,
    epochs: int,
    nmb_crops: tuple[int, ...],
    crops_for_assign: tuple[int, ...],
    model_architecture: a6.models.resnet.Architecture = a6.models.resnet.Architecture.ResNet50,
    batch_size: int = 64,
    drop_last: bool = False,
    hidden_mlp: int = 2048,
    feature_dimensions: int = 128,
    nmb_kmeans_iters: int = 10,
    nmb_prototypes: int = 3,
    nmb_clusters: int = 40,
    base_lr: float = 4.8,
    weight_decay: float = 1e-6,
    temperature: float = 0.1,
):
    """Train a DeepCluster-v2 model on ``train_dataset`` on ``cuda:0``.

    Builds the data loader, model and SGD optimizer, initializes the memory
    banks and then runs ``train`` for ``epochs`` epochs, writing per-epoch
    ``(epoch, loss)`` statistics to ``stats.csv``.

    Parameters
    ----------
    train_dataset : a6.datasets.crop.Base
        Multi-crop dataset returning ``(index, crops)``.
    epochs : int
        Number of epochs to train.
    nmb_crops : tuple[int, ...]
        Number of crops per resolution.
    crops_for_assign : tuple[int, ...]
        Crops stored in the memory banks.
    model_architecture : a6.models.resnet.Architecture
        Backbone architecture.
    batch_size : int, default=64
        Batch size of the data loader.
    drop_last : bool, default=False
        Whether the loader drops the last incomplete batch.
    hidden_mlp, feature_dimensions : int
        Sizes of the projection head.
    nmb_kmeans_iters : int, default=10
        k-means iterations per head and epoch.
    nmb_prototypes : int, default=3
        Number of prototype heads.
    nmb_clusters : int, default=40
        Number of clusters of each head.
    base_lr, weight_decay : float
        SGD learning rate and weight decay.
    temperature : float, default=0.1
        Temperature of the loss.
    """
    # One head per prototype, each with ``nmb_clusters`` clusters. Kept in a
    # separate name so ``nmb_prototypes`` keeps meaning "number of heads".
    prototypes_per_head = [nmb_clusters] * nmb_prototypes
    logger.info("Prototypes for model: %s", prototypes_per_head)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=True,
        # ``drop_last=True`` gives each device the same amount of samples,
        # but removes some from the clustering.
        drop_last=drop_last,
        worker_init_fn=a6.utils.distributed.set_dataloader_seeds,
    )
    logger.info("Building data done with %s images loaded", len(train_dataset))

    device = torch.device("cuda:0")

    # build model
    model = a6.models.resnet.Models[model_architecture](
        normalize=True,
        in_channels=train_dataset.n_channels,
        hidden_mlp=hidden_mlp,
        output_dim=feature_dimensions,
        nmb_prototypes=prototypes_per_head,
        device=device,
    )
    # SyncBatchNorm only matters for multi-process training and, depending on
    # the PyTorch version, may need an initialized process group.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Copy model to GPU
    model = model.to(device)

    logger.info(model)
    logger.info("Building model done")

    # build optimizer
    # Should be done after moving the model to GPU
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=base_lr,
        momentum=0.9,
        weight_decay=weight_decay,
    )

    logger.info("Building optimizer done")
    training_stats = logs.Stats("stats.csv", columns=("epoch", "loss"))

    indexes, embeddings = init_memory(
        dataloader=train_loader,
        model=model,
        device=device,
        feature_dimensions=feature_dimensions,
        crops_for_assign=crops_for_assign,
        drop_last=drop_last,
    )

    for epoch in range(epochs):
        logger.info("============ Starting epoch %i ============", epoch)

        # train the network for one epoch
        epoch_stats, indexes, embeddings = train(
            dataloader=train_loader,
            model=model,
            optimizer=optimizer,
            epoch=epoch,
            n_epochs=epochs,
            indexes=indexes,
            embeddings=embeddings,
            nmb_crops=nmb_crops,
            crops_for_assign=crops_for_assign,
            nmb_prototypes=prototypes_per_head,
            feature_dimensions=feature_dimensions,
            nmb_kmeans_iters=nmb_kmeans_iters,
            device=device,
            temperature=temperature,
        )
        training_stats.update(epoch_stats)


# Train for 10 epochs with 3 prototype heads of 40 clusters each.
start(
    nmb_prototypes=3,
    nmb_clusters=40,
    batch_size=64,
    epochs=10,
    train_dataset=train_dataset,
    nmb_crops=nmb_crops,
    crops_for_assign=crops_for_assign,
)