In [5]:
# viz_embeddings_fiftyone.py
import time
import torch
import numpy as np
import fiftyone as fo
import fiftyone.brain as fob
from fiftyone import Classification, Classifications
from torch.utils.data import TensorDataset, DataLoader


In [6]:
def _load_concat_dataset(embed_paths, label_paths):
    """
    Load multiple .pt files and concatenate along dim 0
    """
    X_list = [torch.load(p) for p in embed_paths]
    y_list = [torch.load(p) for p in label_paths]
    
    X = torch.cat(X_list, dim=0)
    y = torch.cat(y_list, dim=0)
    
    return TensorDataset(X, y)

In [7]:
# Root folder of the preprocessed tensors. Edit this single value (rather than
# four separate absolute paths) when running the notebook on another machine.
DATA_DIR = '/home/free4ky/projects/chest-diseases/data'

# Combine the files from `preprocessed_mosmed` and `preprocessed_val_20` into
# one dataset. Both lists are concatenated in order, so the i-th label file
# must describe the same rows as the i-th embedding file.
val_ds = _load_concat_dataset(
    embed_paths=[
        f'{DATA_DIR}/preprocessed_mosmed/test_data.pt',
        f'{DATA_DIR}/preprocessed_val_20/validation_data.pt',
    ],
    label_paths=[
        f'{DATA_DIR}/preprocessed_mosmed/test_labels.pt',
        f'{DATA_DIR}/preprocessed_val_20/validation_labels.pt',
    ],
)

In [8]:

# --- 1) take the tensors out of `val_ds` (expected — embeddings: [N,512], labels: [N,20]) ---

emb, labels = val_ds.tensors[0], val_ds.tensors[1]

# Convert to NumPy; detach() + cpu() first so GPU tensors are handled as well
emb = emb.detach().cpu().numpy() if isinstance(emb, torch.Tensor) else emb
labels = labels.detach().cpu().numpy() if isinstance(labels, torch.Tensor) else labels

# Sanity checks: two 2D arrays with exactly one row per sample
if emb.ndim != 2:
    raise ValueError(f"embeddings must be 2D (n_samples x n_dims); got shape {emb.shape}")
if labels.ndim != 2:
    raise ValueError(f"labels must be 2D (n_samples x n_classes); got shape {labels.shape}")
(n_samples, emb_dim) = emb.shape
if len(labels) != n_samples:
    raise ValueError("number of rows in embeddings and labels must match")

# One label column per class
n_classes = labels.shape[-1]

# --- 2) class names: generic placeholders, one per label column ---
# Swap in your own list of human-readable names here if you have one
class_names = [f"class_{i}" for i in range(n_classes)]

# --- 3) turn every label row into the list of class names that are "on" ---
# Float labels are treated as scores and thresholded at 0.5;
# int / bool labels are treated as flags (non-zero means present).
label_mask = (
    labels >= 0.5
    if np.issubdtype(labels.dtype, np.floating)
    else labels.astype(bool)
)

label_lists = [
    [name for name, present in zip(class_names, row) if present]
    for row in label_mask
]

# optional: one label per sample (first active class, or "NONE" when no class
# is active) for simpler single-colour visualizations
top_labels = [names[0] if names else "NONE" for names in label_lists]

# --- 4) build a FiftyOne dataset: one sample per row, with a multilabel field ---
# NOTE(review): the run recorded at the end of this notebook failed with a
# FiftyOneConfigError (no MongoDB available); its message points at the
# `database_uri` setting of FiftyOneConfig — confirm the setup before rerunning.
# The Unix timestamp in the name avoids clobbering existing datasets.
dataset_name = f"embeddings_viz_{int(time.time())}"
dataset = fo.Dataset(dataset_name)

# Samples are created in the same order as the rows of `emb`, so that
# sample i lines up with embedding i.
samples = []
for i, labs in enumerate(label_lists):
    # every active class becomes one Classification inside the multilabel field
    multilabel = Classifications(
        classifications=[Classification(label=name) for name in labs]
    )
    # the filepath is only a placeholder — use real image paths if you have them
    sample = fo.Sample(filepath=f"sample_{i}.jpg")
    sample["labels_mult"] = multilabel
    samples.append(sample)

# single bulk insert instead of one call per sample
dataset.add_samples(samples)

# --- 5) compute a 2D visualization (UMAP) from the precomputed embeddings ---
# `emb` is passed as-is; row i must belong to sample i of `dataset`.
results = fob.compute_visualization(
    dataset,
    embeddings=emb,   # numpy array of shape (n_samples, emb_dim)
    method="umap",    # 'umap' | 'tsne' | 'pca'
    num_dims=2,       # 2D for interactive plots
    seed=51,          # fixed seed so the layout is the same on every rerun
)

# --- 6) show an interactive scatterplot coloured by class ---
# `labels` needs exactly one value per point. The per-sample lists in
# `label_lists` have varying lengths, so they cannot be used directly here;
# colour by the single "top" label of each sample instead. The complete
# multilabel information remains available in the `labels_mult` sample field.
plot = results.visualize(labels=top_labels)
plot.show()  # opens an interactive window / notebook renderer

# --- 7) (optional) index the 2D points back into the dataset for lasso selection & querying ---
# store points into a sample field "umap_points"; create a spatial index for fast queries
results.index_points(points_field="umap_points", create_index=True)

# --- 8) (optional) attach the plot to a FiftyOne App session for synced exploration ---
session = fo.launch_app(dataset)          # launches the FiftyOne App
session.plots["embeddings_umap"] = plot   # registers the plot with the session under this name
print(f"Dataset: {dataset_name} — open FiftyOne App to explore samples & the embeddings plot.")


FiftyOneConfigError: MongoDB could not be installed on your system. Please define a `database_uri` in your `fiftyone.core.config.FiftyOneConfig` to connect to yourown MongoDB instance or cluster 