In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported before each cell runs and no kernel restart is needed.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import plotly.express
import matplotlib.pyplot as plt
import sklearn.manifold
import gc
import sys
import guppy
import tqdm
import memory_profiler
import torch
import umap

import a2.training.training_hugging
import a2.training.evaluate_hugging
import a2.training.dataset_hugging
import a2.dataset

In [None]:
# --- Model locations -------------------------------------------------------
FOLDER_MODEL_PRETRAINED = "../../models/model_weights/output_rainprediction_simpledeberta/era5/checkpoint-7617/"
FOLDER_MODEL = "microsoft/deberta-v3-small"

# --- Tweets dataset --------------------------------------------------------
# Alternative tweets file (currently disabled):
# FILE_TWEETS = "../../../maelstrom_bootcamp/Applications/AP2/bootcamp2022_data/tweets/tweets_2017_01_era5_normed_filtered.nc"
FOLDER_TWEETS = "/home/kristian/Projects/a2/data/tweets/"
FILE_TWEETS = f"{FOLDER_TWEETS}tweets_2017_era5_normed_filtered_predicted_simpledeberta.nc"

# --- Pre-computed embeddings (NumPy .npy file) ------------------------------
FOLDER_EMBEDDINGS = "/home/kristian/Projects/a2/data/embeddings/cls_token/"
FILE_EMBEDDINGS = f"{FOLDER_EMBEDDINGS}cls_tokenstweets_2017_era5_normed_filtered.nc.npy"
# Sanity check: list both input files via the shell so a wrong path is
# visible here rather than as a load error in a later cell.
!ls $FILE_TWEETS
!ls $FILE_EMBEDDINGS

In [None]:
# Read the tweets dataset from disk.
ds = a2.dataset.load_dataset.load_tweets_dataset(FILE_TWEETS)
# Binary label along the "index" dimension: 1 where `tp_h` exceeds 1e-8, else 0.
ds["raining"] = (["index"], (ds.tp_h.values > 1e-8).astype(int))

In [None]:
# Load the pre-computed embeddings array from the .npy file.
# NOTE(review): later cells index this array with positions drawn from the
# length of `ds.index`, so its rows presumably align with `ds` — TODO confirm.
cls_tokens = np.load(FILE_EMBEDDINGS)

In [None]:
# Subsample before running UMAP. Despite the name, `mask` is used below as an
# index selection (`cls_tokens[mask]`), not as a boolean mask.
# NOTE(review): presumably 10000 random positions out of `ds.index.shape[0]`;
# no seed is passed here, so the sample may differ between runs — TODO confirm.
mask = a2.utils.utils.get_random_indices(10000, ds.index.shape[0])

In [None]:
# Baseline embedding: UMAP with its default parameters, fitted on the
# subsampled CLS-token embeddings.
fit = umap.UMAP()
projections = fit.fit_transform(cls_tokens[mask])

In [None]:
# UMAP hyper-parameter values explored by the sweep further down.
n_neighbors = [10, 200]
min_dist = [0.1, 0.8]
metrics = ["correlation", "mahalanobis", "wminkowski"]

# Grid layout: one row per `min_dist` value, one column per `n_neighbors` value.
n_rows, n_cols = len(min_dist), len(n_neighbors)

# Plotting backend; any value other than "plotly" prepares a matplotlib grid.
backend = "plotly"
if not backend == "plotly":
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 5 * n_rows))


def cluster_plot(
    ax: plt.Axes,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    n_components: int = 2,
    metric: str = "euclidean",
    title: str = "",
    backend: str = "plotly",
):
    """Fit UMAP on the subsampled embeddings and scatter the first two components.

    Reads the notebook-level globals ``cls_tokens`` (embedding array), ``mask``
    (indices of the subsample) and ``ds`` (tweets dataset); they must be
    defined before this function is called.

    Parameters
    ----------
    ax:
        Matplotlib axes to draw into. Only used when ``backend`` is not
        ``"plotly"``.
    n_neighbors, min_dist, n_components, metric:
        Forwarded unchanged to ``umap.UMAP``. Components 0 and 1 are plotted,
        so ``n_components`` must be at least 2.
    title:
        Figure / axes title.
    backend:
        ``"plotly"`` shows an interactive plotly express scatter, faceted by
        ``raining`` and coloured by ``prediction_probability_raining``; any
        other value draws a plain scatter on ``ax``.
    """
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=n_components,
        metric=metric,
    )
    projections = reducer.fit_transform(cls_tokens[mask])
    x_values, y_values = projections.T[0], projections.T[1]

    if backend == "plotly":
        hover_keys = ["text_normalized", "raining"]
        # `mask` indexes `cls_tokens` by position, so the dataset rows must be
        # picked by position too (`isel`). `sel` would interpret the values as
        # coordinate labels and could pair points with the wrong tweets.
        sampled_frame = ds.isel(index=mask).to_dataframe()
        fig = plotly.express.scatter(
            data_frame=sampled_frame,
            x=x_values,
            y=y_values,
            title=title,
            color="prediction_probability_raining",
            hover_data=hover_keys,
            facet_col="raining",
            color_continuous_scale="Aggrnyl",
            opacity=0.1,
        )
        fig.show()
    else:
        ax.scatter(x=x_values, y=y_values, alpha=0.1)
        ax.set_title(title)


# Sweep the UMAP hyper-parameter grid and draw one projection per combination.
axes = a2.plotting.utils_plotting.create_axes_grid(len(min_dist), len(n_neighbors))
for i_n, n_ngb in enumerate(n_neighbors):
    for j_d, dist in enumerate(min_dist):
        # Iterate over the metric names directly: `enumerate(metrics)` would
        # yield `(index, name)` tuples, which are neither a valid UMAP `metric`
        # nor a readable title.
        for metric in metrics:
            # NOTE(review): all metrics share the grid cell (j_d, i_n), so with
            # a matplotlib backend their scatters are drawn on top of each
            # other — confirm whether a per-metric grid is wanted.
            ax = axes[j_d, i_n]
            cluster_plot(
                ax,
                n_neighbors=n_ngb,
                min_dist=dist,
                n_components=2,
                metric=metric,
                backend=backend,
                title=f"n_neighbors: {n_ngb}, min_dist: {dist}, metric: {metric}",
            )