In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from umap import UMAP

# Per-split embedding files; `{mode}` is filled with each entry of `modes`.
EMBEDDINGS_PATH_TEMPLATE = (
    "../outputs/embeddings/OOPS_cs_{mode}_16@7_5_Qwen3-VL-Embedding-2B_448.pt"
)

modes = ["train", "val", "test"]
embeddings: dict = {}
samples: dict = {}
for mode in modes:
    # map_location="cpu": files written from a GPU run still load on CPU-only machines,
    # and the tensors stay on the CPU for the NumPy/UMAP conversions later on.
    embed_output: dict = torch.load(
        EMBEDDINGS_PATH_TEMPLATE.format(mode=mode), map_location="cpu"
    )
    embeddings[mode] = embed_output["embeddings"]
    samples[mode] = pd.DataFrame(embed_output["samples"])
    # Retrieval and plotting index both by row position, so the two must align.
    assert len(embeddings[mode]) == len(samples[mode]), (
        f"{mode}: {len(embeddings[mode])} embeddings vs {len(samples[mode])} samples"
    )

In [None]:
# load the datasets
from hydra import compose, initialize

from falldet.data.video_dataset_factory import get_video_datasets
from falldet.schemas import InferenceConfig

with initialize(version_base=None, config_path="../config/"):
    # Compose the Hydra config and validate it into the typed InferenceConfig in one step.
    cfg = InferenceConfig.model_validate(compose(config_name="inference_config"))

In [None]:
# Override the frame count before building the per-split OOPS_cs video datasets.
cfg.dataset.vid_frame_count = 9
dataset_kwargs = {"return_individual": True, "size": 448, "seed": None}
datasets = {
    mode: get_video_datasets(cfg, mode=mode, **dataset_kwargs)["individual"]["OOPS_cs"]
    for mode in modes
}

In [None]:
# Shapes of the train and test embedding tensors.
(embeddings["train"].shape, embeddings["test"].shape)

In [None]:
def compute_cosine_similarity(queries: torch.Tensor, corpus: torch.Tensor) -> torch.Tensor:
    """
    Compute the pairwise cosine similarity between every query and every corpus embedding.

    Args:
        queries: A tensor of shape (num_queries, embedding_dim).
        corpus: A tensor of shape (num_corpus, embedding_dim).

    Returns:
        A tensor of shape (num_queries, num_corpus) whose entry [i, j] is the cosine
        similarity between queries[i] and corpus[j] (in [-1, 1] up to rounding).
    """
    # L2-normalize each row. F.normalize divides by max(norm, eps), so an all-zero row
    # becomes all zeros (similarity 0 to everything) rather than producing NaNs.
    queries_normalized = F.normalize(queries, p=2, dim=1)
    corpus_normalized = F.normalize(corpus, p=2, dim=1)
    # Dot products of unit vectors are exactly the cosine similarities.
    return queries_normalized @ corpus_normalized.T

In [None]:
# Expected: (number of test samples, number of train samples).
compute_cosine_similarity(queries=embeddings["test"], corpus=embeddings["train"]).shape

In [None]:
def retrieve_top_k_nearest_neighbors(
    queries: torch.Tensor, corpus: torch.Tensor, k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Find, for every query, the `k` corpus entries with the highest cosine similarity.

    Args:
        queries: Query embeddings of shape (num_queries, embedding_dim).
        corpus: Corpus embeddings to search, of shape (num_corpus, embedding_dim).
        k: How many neighbors to return per query.

    Returns:
        A pair `(indices, scores)`, both of shape (num_queries, k): `indices` holds the
        corpus row indices of the neighbors and `scores` the matching cosine
        similarities, ordered from most to least similar.
    """
    scores, indices = compute_cosine_similarity(queries, corpus).topk(k, dim=1)
    return indices, scores

In [None]:
# Top-5 training neighbors for every test embedding.
nns, similarities = retrieve_top_k_nearest_neighbors(
    queries=embeddings["test"], corpus=embeddings["train"], k=5
)

In [None]:
# Neighbor indices and similarity scores for test sample 0.
(nns[0], similarities[0])

In [None]:
from falldet.visualization import visualize_video


def visualize_query_and_neighbors(
    query_index: int,
    neighbor_indices: torch.Tensor,
    similarities: torch.Tensor,
    query_mode: str = "test",
    neighbor_mode: str = "train",
    nrow: int = 9,
) -> None:
    """
    Show the query video followed by each of its retrieved nearest neighbors.

    Reads the notebook-level globals `samples` (per-split DataFrames with a
    `label_str` column) and `datasets` (per-split video datasets).

    Args:
        query_index: Row index of the query in `samples[query_mode]` and
            `datasets[query_mode]`.
        neighbor_indices: 1-D tensor of row indices into `samples[neighbor_mode]` /
            `datasets[neighbor_mode]`, e.g. one row of the indices returned by
            `retrieve_top_k_nearest_neighbors`.
        similarities: 1-D tensor of similarity scores aligned with `neighbor_indices`.
        query_mode: Split the query comes from.
        neighbor_mode: Split the neighbors come from.
        nrow: Forwarded to `visualize_video` as `nrow` (9 was previously hard-coded).
    """
    query_index = int(query_index)
    query_label = samples[query_mode].iloc[query_index]["label_str"]

    print(f"Query Video (Index: {query_index}, Label: {query_label}):")
    visualize_video(idx=query_index, dataset=datasets[query_mode], nrow=nrow)
    plt.show()

    for rank, (neighbor_index, similarity) in enumerate(
        zip(neighbor_indices, similarities), start=1
    ):
        # Unwrap 0-d tensors into Python scalars: they then print as "42" instead of
        # "tensor(42)", and the dataset is indexed with a plain int.
        neighbor_index = int(neighbor_index)
        similarity = float(similarity)
        neighbor_label = samples[neighbor_mode].iloc[neighbor_index]["label_str"]
        print(
            f"Neighbor {rank} (Index: {neighbor_index}, Label: {neighbor_label}, "
            f"Similarity: {similarity:.4f}):"
        )
        visualize_video(idx=neighbor_index, dataset=datasets[neighbor_mode], nrow=nrow)
        plt.show()

In [None]:
idx = 4  # row of the test split to inspect
visualize_query_and_neighbors(idx, nns[idx], similarities[idx])

In [None]:
from matplotlib import colormaps


def plot_2d_scatter(
    embeddings: torch.Tensor,
    labels: list[str],
    title: str = "UMAP 2D Projection of Embeddings",
    random_state: int | None = 42,
) -> plt.Axes:
    """
    Project embeddings to 2D with UMAP and draw a scatter plot colored by label.

    Args:
        embeddings: A tensor of shape (num_samples, embedding_dim) to project.
        labels: num_samples string labels (list, NumPy array or pandas Series), one per
            embedding row, used for coloring and the legend.
        title: The title of the plot.
        random_state: Seed for UMAP so the projection is reproducible; pass None for a
            non-deterministic (parallel) fit.

    Returns:
        The matplotlib Axes holding the scatter plot.

    Raises:
        ValueError: If the number of labels differs from the number of embeddings.
    """
    # .float() upcasts half/bfloat16 tensors, which .numpy() cannot convert directly.
    features = embeddings.detach().cpu().float().numpy()
    # An array is needed for element-wise masks: `some_list == "x"` is a single bool.
    label_array = np.asarray(labels)
    if len(label_array) != len(features):
        raise ValueError(f"Got {len(label_array)} labels for {len(features)} embeddings.")

    # umap-learn forces single-threaded execution when seeded; asking for it explicitly
    # avoids its n_jobs override warning.
    reducer = UMAP(
        n_components=2,
        random_state=random_state,
        n_jobs=1 if random_state is not None else -1,
    )
    proj = reducer.fit_transform(features)

    unique_labels = sorted(set(label_array.tolist()))
    cm = colormaps["tab20"]
    color_map = {label: cm(i / len(unique_labels)) for i, label in enumerate(unique_labels)}

    _, ax = plt.subplots(figsize=(10, 8))
    for label in unique_labels:
        mask = label_array == label
        ax.scatter(proj[mask, 0], proj[mask, 1], c=[color_map[label]], label=label, s=40, alpha=0.8)

    ax.set_title(title)
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    ax.legend(title="Label", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    return ax

In [None]:
# The trailing ";" suppresses the returned Axes repr; the figure is still rendered.
plot_2d_scatter(
    embeddings["test"], samples["test"].label_str, title="UMAP Projection of Test Embeddings"
);

In [None]:
def plot_3d_scatter(
    embeddings: torch.Tensor,
    labels: list[str],
    title: str = "UMAP 3D Projection of Embeddings",
    random_state: int | None = 42,
) -> plt.Axes:
    """
    Project embeddings to 3D with UMAP and draw a 3D scatter plot colored by label.

    Args:
        embeddings: A tensor of shape (num_samples, embedding_dim) to project.
        labels: num_samples string labels (list, NumPy array or pandas Series), one per
            embedding row, used for coloring and the legend.
        title: The title of the plot (previously accepted but ignored).
        random_state: Seed for UMAP so the projection is reproducible; pass None for a
            non-deterministic (parallel) fit.

    Returns:
        The 3D matplotlib Axes holding the scatter plot.

    Raises:
        ValueError: If the number of labels differs from the number of embeddings.
    """
    # .cpu() handles CUDA tensors; .float() handles half/bfloat16, which .numpy() rejects.
    features = embeddings.detach().cpu().float().numpy()
    # An array is needed for element-wise masks: `some_list == "x"` is a single bool.
    label_array = np.asarray(labels)
    if len(label_array) != len(features):
        raise ValueError(f"Got {len(label_array)} labels for {len(features)} embeddings.")

    # umap-learn forces single-threaded execution when seeded; asking for it explicitly
    # avoids its n_jobs override warning.
    reducer_3d = UMAP(
        n_components=3,
        random_state=random_state,
        n_jobs=1 if random_state is not None else -1,
    )
    proj_3d = reducer_3d.fit_transform(features)

    fig_3d = plt.figure(figsize=(10, 8))
    ax_3d = fig_3d.add_subplot(111, projection="3d")
    unique_labels = sorted(set(label_array.tolist()))
    cm = colormaps["tab20"]
    color_map = {
        label_name: cm(i / len(unique_labels)) for i, label_name in enumerate(unique_labels)
    }

    for label_name in unique_labels:
        mask = label_array == label_name
        ax_3d.scatter(
            proj_3d[mask, 0],
            proj_3d[mask, 1],
            proj_3d[mask, 2],
            c=[color_map[label_name]],
            label=label_name,
            s=30,
            alpha=0.8,
        )

    ax_3d.legend(title="Label", bbox_to_anchor=(1.15, 1), loc="upper left", fontsize=7)
    ax_3d.set_title(title)
    ax_3d.set_xlabel("UMAP 1")
    ax_3d.set_ylabel("UMAP 2")
    ax_3d.set_zlabel("UMAP 3")
    plt.tight_layout()
    return ax_3d

In [None]:
# The trailing ";" suppresses the returned Axes repr; the figure is still rendered.
plot_3d_scatter(
    embeddings["test"], samples["test"].label_str, title="UMAP Projection of Test Embeddings"
);

In [None]:
import plotly.express as px
import plotly.graph_objects as go

split = "test"

labels = samples[split].label_str
unique_labels = sorted(set(labels))
# Embeddings of the selected split (this is not necessarily the training split).
split_embeddings = embeddings[split]

colors = px.colors.qualitative.Dark24

# .float() upcasts half/bfloat16 tensors, which .numpy() cannot convert directly.
proj_3d = UMAP(n_components=3, random_state=42, n_jobs=1).fit_transform(
    split_embeddings.detach().cpu().float().numpy()
)

fig_3d_interactive = go.Figure()
for i, label_name in enumerate(unique_labels):
    mask_3d = (labels == label_name).to_numpy()
    fig_3d_interactive.add_trace(
        go.Scatter3d(
            x=proj_3d[mask_3d, 0],
            y=proj_3d[mask_3d, 1],
            z=proj_3d[mask_3d, 2],
            mode="markers",
            marker=dict(size=5, opacity=0.8, color=colors[i % len(colors)]),
            name=label_name,
        )
    )

fig_3d_interactive.update_layout(
    title=f"UMAP 3D Projection of {split.capitalize()} Embeddings",
    scene=dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2", zaxis_title="UMAP 3"),
    width=900,
    height=700,
    legend_title="Label",
)
fig_3d_interactive.show()