In [None]:
import os
from collections import Counter

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from ipywidgets import interact, interactive, IntSlider
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [None]:
# Directory holding the samples CSV and the embeddings file. Set the
# CARDIOMEGALY_DATA_DIR environment variable to run the notebook on another
# machine without editing it; the default keeps the original location.
DATA_DIR = os.environ.get("CARDIOMEGALY_DATA_DIR", "/home/andrej/data")

# CSV read into the samples dataframe (its "labels" column drives the plots).
SAMPLES_FILE = os.path.join(DATA_DIR, "cardiomegaly_frontal_and_standard_validation_dataset.csv")
# Serialized torch object containing the embeddings to project with t-SNE.
EMBEDDINGS_FILE = os.path.join(DATA_DIR, "cardiomegaly_frontal_and_standard_embeddings_epoch_4_20250227_013258_utc.pt")

In [None]:
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array.

    The tensor is first detached from the autograd graph and moved to the CPU,
    because ``Tensor.numpy()`` refuses tensors that require grad or that live
    on another device. bfloat16 tensors are upcast to float32 since NumPy has
    no bfloat16 dtype; every other dtype is passed through unchanged.

    Args:
        tensor: The ``torch.Tensor`` to convert.

    Returns:
        A ``numpy.ndarray`` holding the tensor's values (float32 for bfloat16
        input, otherwise the matching NumPy dtype).
    """
    # Both calls are no-ops for a plain CPU tensor that does not require grad.
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.to(torch.float32)
    return tensor.numpy()

def plot_tsne_with_df_labels(perplexity=30, df=None, embeddings=None):
    """Project embeddings to 2-D with t-SNE and plot them coloured by label.

    Points whose ``labels`` value equals the string ``"['Cardiomegaly']"`` are
    coloured as class 1; every other value is coloured as class 0.

    Args:
        perplexity: Perplexity passed to ``sklearn.manifold.TSNE``.
        df: Dataframe with a ``labels`` column, one row per embedding and in
            the same order as the embeddings.
        embeddings: Array of embeddings to project. When omitted, the
            notebook-level ``embeddings_np`` variable is used.

    Returns:
        The t-SNE coordinates as returned by ``TSNE.fit_transform`` (one row
        per embedding, two columns).

    Raises:
        ValueError: If ``df`` is missing, has no ``labels`` column, or its
            number of rows differs from the number of embeddings.
    """
    # Validate everything up front: the t-SNE fit below is the slow step, so
    # a bad argument should fail before it runs rather than after.
    if df is None or "labels" not in df:
        raise ValueError("A valid dataframe with a 'labels' column must be provided.")

    if embeddings is None:
        # Fall back to the array produced by the embeddings-loading cell.
        embeddings = embeddings_np

    if len(embeddings) != len(df):
        raise ValueError(
            f"Number of embeddings ({len(embeddings)}) does not match "
            f"number of dataframe rows ({len(df)})."
        )

    # Vectorised 0/1 encoding of the label column.
    numeric_labels = (df["labels"] == "['Cardiomegaly']").astype(int)

    # Fixed random_state keeps the layout reproducible between runs.
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)

    fig, ax = plt.subplots(figsize=(10, 7))
    scatter = ax.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        s=10,
        alpha=0.8,
        c=numeric_labels,
        cmap="viridis",
    )
    ax.set_title(f"t-SNE projection (perplexity={perplexity})")
    ax.set_xlabel("t-SNE Component 1")
    ax.set_ylabel("t-SNE Component 2")
    fig.colorbar(scatter, ax=ax, label="Label (0: [], 1: ['Cardiomegaly'])")
    plt.show()

    return embeddings_2d

def get_events_in_region(embeddings_2d, df, x_range, y_range, target_label="[]"):
    """Select rows of ``df`` whose 2-D point lies in a rectangle and whose label matches.

    Row ``i`` of ``embeddings_2d`` is taken to describe row ``i`` of ``df``
    (positional alignment, independent of the dataframe index).

    Args:
        embeddings_2d: Array-like with one row per sample; column 0 is the x
            coordinate and column 1 the y coordinate.
        df: Dataframe with a ``labels`` column, one row per point.
        x_range: ``(x_min, x_max)`` bounds, both inclusive.
        y_range: ``(y_min, y_max)`` bounds, both inclusive.
        target_label: Value the ``labels`` column must equal for a row to be kept.

    Returns:
        A tuple ``(indices_final, filtered_df)``: the sorted positional indices
        of the matching rows and the corresponding ``df.iloc`` selection.

    Raises:
        ValueError: If ``embeddings_2d`` and ``df`` have different lengths.
    """
    points = np.asarray(embeddings_2d)
    # Without this check a mismatch would silently pair points with the wrong
    # rows, or only surface later as an out-of-bounds positional lookup.
    if len(points) != len(df):
        raise ValueError(
            f"embeddings_2d has {len(points)} rows but df has {len(df)} rows."
        )

    x_min, x_max = x_range
    y_min, y_max = y_range
    xs, ys = points[:, 0], points[:, 1]

    region_mask = (xs >= x_min) & (xs <= x_max) & (ys >= y_min) & (ys <= y_max)
    label_mask = (df["labels"] == target_label).to_numpy()

    # One combined mask replaces intersecting two separately built index arrays.
    indices_final = np.flatnonzero(region_mask & label_mask)
    filtered_df = df.iloc[indices_final]

    return indices_final, filtered_df

In [None]:
# Load the samples table that supplies the "labels" column used for colouring
# and filtering further down.
print(f"Reading samples file {SAMPLES_FILE}")
df = pd.read_csv(SAMPLES_FILE)

# Sanity output: number of rows, then the count of each distinct label value.
print(len(df.index))
print(Counter(df["labels"]))

In [None]:
print(f"Reading embeddings file {EMBEDDINGS_FILE}")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load embeddings from a trusted source (or pass weights_only=True
# on torch versions that support it).
embeddings = torch.load(EMBEDDINGS_FILE, map_location=torch.device("cpu"))

# Convert whatever container was saved into NumPy, keeping its structure.
if isinstance(embeddings, torch.Tensor):
    embeddings_np = to_numpy(embeddings)
elif isinstance(embeddings, dict):
    embeddings_np = {key: to_numpy(tensor) for key, tensor in embeddings.items()}
elif isinstance(embeddings, list):
    embeddings_np = [to_numpy(tensor) for tensor in embeddings]
else:
    raise TypeError(f"Unsupported embeddings format: {type(embeddings)}")

print(type(embeddings_np))
# Dicts and lists have no .shape attribute, so report the shape of each entry
# instead of failing with AttributeError.
if isinstance(embeddings_np, dict):
    print({key: array.shape for key, array in embeddings_np.items()})
elif isinstance(embeddings_np, list):
    print([array.shape for array in embeddings_np])
else:
    print(embeddings_np.shape)

In [None]:
# Fit and plot the t-SNE projection, keeping the 2-D coordinates so that
# regions of the plot can be queried afterwards.
embeddings_2d = plot_tsne_with_df_labels(df=df, perplexity=10)

In [None]:
# Show the samples labelled "[]" whose t-SNE coordinates fall inside the
# square 0 <= x <= 25, 0 <= y <= 25.
get_events_in_region(
    embeddings_2d,
    df,
    x_range=(0, 25),
    y_range=(0, 25),
    target_label="[]",
)