# Probe label accuracy and homogeneity using a k-NN classifier

This notebook extracts previously generated model embeddings and evaluates them with a cross-validated k-NN classifier. The out-of-fold predictions quantify how well the labelled classes separate in each embedding space and help flag labels that are hard to predict or possibly incorrect.

## Imports and parameters for embedding extraction and k-NN evaluation

In [3]:
# imports
from gelos.embedding_extraction import extract_embeddings
from gelos.embedding_generation import perturb_args_to_string
import geopandas as gpd
import yaml
from gelos.config import PROJ_ROOT, PROCESSED_DATA_DIR, DATA_VERSION, RAW_DATA_DIR
from gelos.config import REPORTS_DIR, FIGURES_DIR
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.neighbors import KNeighborsClassifier
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
try:
    import faiss  # fast k-NN for high-dimensional data
    _faiss_available = True
except ImportError:
    faiss = None
    _faiss_available = False

In [7]:
# --- Models to evaluate ---------------------------------------------------------
# Each entry is a YAML file name that run_config resolves under
# PROJ_ROOT / "gelos" / "configs"; extend the list to evaluate more models.
config_yaml_names = [
    "prithvieov2300_noperturb.yaml",
    "prithvieov2600_noperturb.yaml",
    "terramindv1base_noperturb.yaml",
]

# --- Embedding extraction -------------------------------------------------------
# Not used by the cells shown here; presumably consumed by later cells.
n_timesteps = 4
# The extraction strategy is chosen per run_config call: it must be a key of the
# config's `embedding_extraction_strategies` (e.g. "All Steps of Middle Patch").

# --- k-NN evaluation ------------------------------------------------------------
n_neighbors = 5            # k for the majority vote
fast_knn_method = "faiss"  # "faiss" (run_config falls back to scikit-learn if faiss is missing) or "sklearn"
max_workers = None         # presumably forwarded to ProcessPoolExecutor; set an int to cap parallelism

## Extract embeddings and evaluate a cross-validated k-NN classifier

In [5]:
def _l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row of ``x`` to unit Euclidean length.

    A tiny epsilon (1e-12) is added to every row norm so that all-zero rows
    come back as zeros instead of NaN from a 0/0 division.
    """
    row_lengths = np.linalg.norm(x, axis=1, keepdims=True)
    return x / (row_lengths + 1e-12)

def _faiss_knn_predict(train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, k: int) -> np.ndarray:
    """Predict a label for each row of ``test_X`` by majority vote of its nearest training rows.

    Neighbours come from an exact FAISS inner-product index (``IndexFlatIP``).
    This ranks by cosine similarity only when the inputs are already
    L2-normalised.

    Args:
        train_X: 2-D array of training vectors, one sample per row.
        train_y: Labels aligned with the rows of ``train_X``.
        test_X: 2-D array of query vectors with the same width as ``train_X``.
        k: Number of neighbours that vote; clamped to the number of training rows.

    Returns:
        Array with one predicted label per row of ``test_X``, using ``train_y``'s
        dtype. Vote ties go to the smallest label because ``np.unique`` returns
        sorted values.

    Raises:
        ValueError: If ``k`` < 1 or ``train_X`` has no rows.
    """
    n_train = train_X.shape[0]
    # FAISS pads search results with index -1 when k exceeds the index size;
    # train_y[-1] would then silently vote with the last training label.
    k_eff = min(k, n_train)
    if k_eff < 1:
        raise ValueError(f"need k >= 1 and at least one training vector (k={k}, n_train={n_train})")
    # FAISS operates on C-contiguous float32 data; convert explicitly.
    train_f32 = np.ascontiguousarray(train_X, dtype=np.float32)
    test_f32 = np.ascontiguousarray(test_X, dtype=np.float32)
    index = faiss.IndexFlatIP(train_f32.shape[1])
    index.add(train_f32)
    _, neighbour_idx = index.search(test_f32, k_eff)
    neighbour_labels = train_y[neighbour_idx]
    preds = []
    for row in neighbour_labels:
        values, counts = np.unique(row, return_counts=True)
        preds.append(values[np.argmax(counts)])
    return np.array(preds, dtype=train_y.dtype)

def _load_labelled_embeddings(yaml_name: str, extraction_strategy: str, label_col: str = "category"):
    """Load one config's extracted embeddings together with the matching chip labels.

    Returns:
        ``(embeddings, labels, model_title)``.

    Raises:
        KeyError: If ``extraction_strategy`` is not defined in the config.
        FileNotFoundError: If the processed-embeddings directory is missing or empty.
    """
    yaml_path = PROJ_ROOT / "gelos" / "configs" / yaml_name
    with open(yaml_path, "r") as f:
        yaml_config = yaml.safe_load(f)
    chip_gdf = gpd.read_file(RAW_DATA_DIR / DATA_VERSION / "gelos_chip_tracker.geojson")
    # Side effect kept from the original cell; presumably later plotting cells write here.
    (FIGURES_DIR / DATA_VERSION).mkdir(exist_ok=True, parents=True)
    model_name = yaml_config["model"]["init_args"]["model"]
    model_title = yaml_config["model"]["title"]
    strategies = yaml_config["embedding_extraction_strategies"]
    if extraction_strategy not in strategies:
        raise KeyError(
            f"Unknown extraction strategy {extraction_strategy!r} for {yaml_name}; "
            f"available: {list(strategies)}"
        )
    perturb = yaml_config["data"]["init_args"].get("perturb_bands", None)
    output_dir = PROCESSED_DATA_DIR / DATA_VERSION / model_name / perturb_args_to_string(perturb)
    if not output_dir.is_dir():
        raise FileNotFoundError(f"Embeddings output directory not found for {yaml_name}: {output_dir}")
    # iterdir() order is filesystem-dependent; sort so every run picks the same directory.
    embeddings_directories = sorted(p for p in output_dir.iterdir() if p.is_dir())
    if not embeddings_directories:
        raise FileNotFoundError(f"No embeddings directories found for {yaml_name}")
    embeddings, chip_indices = extract_embeddings(
        embeddings_directories[0], slice_args=strategies[extraction_strategy]
    )
    # NOTE(review): assumes chip_indices are positional row numbers of chip_gdf (hence .iloc) — confirm.
    labels = chip_gdf.iloc[chip_indices][label_col].to_numpy()
    return embeddings, labels, model_title


def _cross_val_knn_predict(X, labels, use_faiss: bool, n_neighbors: int, n_splits: int, random_state: int, desc: str):
    """Return out-of-fold k-NN predictions for every row of ``X`` via stratified K-fold CV."""
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    y_pred = np.empty_like(labels)
    for train_idx, test_idx in tqdm(cv.split(X, labels), total=cv.get_n_splits(), desc=desc):
        if use_faiss:
            y_pred[test_idx] = _faiss_knn_predict(X[train_idx], labels[train_idx], X[test_idx], n_neighbors)
        else:
            knn = KNeighborsClassifier(n_neighbors=n_neighbors)
            knn.fit(X[train_idx], labels[train_idx])
            y_pred[test_idx] = knn.predict(X[test_idx])
    return y_pred


def run_config(
    yaml_name: str,
    extraction_strategy: str,
    knn_method: str = "faiss",
    n_neighbors: int = 5,
    label_col: str = "category",
    n_splits: int = 5,
    random_state: int = 42,
):
    """Score how well one model's embeddings separate the chip labels using cross-validated k-NN.

    Args:
        yaml_name: Config file name under ``PROJ_ROOT/gelos/configs``.
        extraction_strategy: Key into the config's ``embedding_extraction_strategies``.
        knn_method: ``"faiss"`` uses FAISS when it is importable. Any other value,
            or FAISS being unavailable, uses scikit-learn.
        n_neighbors: Number of neighbours in the k-NN vote.
        label_col: Column of the chip tracker that holds the class labels.
        n_splits: Number of stratified cross-validation folds.
        random_state: Seed for the fold shuffle.

    Returns:
        ``(yaml_name, extraction_strategy, model_title, overall_acc, per_class)``.
        ``per_class`` maps each label to its per-class accuracy (recall) over the
        out-of-fold predictions.

    Raises:
        KeyError: If ``extraction_strategy`` is not defined in the config.
        FileNotFoundError: If no processed embeddings exist for the config.
    """
    embeddings, labels, model_title = _load_labelled_embeddings(yaml_name, extraction_strategy, label_col)
    use_faiss = knn_method == "faiss" and _faiss_available
    # Normalise for BOTH back-ends. On unit vectors, Euclidean distance (sklearn's
    # default) ranks neighbours exactly like inner product (FAISS IndexFlatIP).
    # Results therefore no longer depend on whether faiss is installed.
    X = _l2_normalize(embeddings)
    y_pred = _cross_val_knn_predict(
        X, labels, use_faiss, n_neighbors, n_splits, random_state, desc=f"{yaml_name} folds"
    )
    overall_acc = accuracy_score(labels, y_pred)
    classes = np.unique(labels)
    cm = confusion_matrix(labels, y_pred, labels=classes)
    # Every entry of `classes` occurs in `labels`, so no confusion-matrix row sums to zero.
    per_class = dict(zip(classes, cm.diagonal() / cm.sum(axis=1)))
    return yaml_name, extraction_strategy, model_title, overall_acc, per_class

## Visualize examples with high and low accuracies