In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from preprocess import preprocess_data
from embeddings import Embeddings
from curate_using_LCA import curate_using_LCA, generate_wgtr_calibration_ground_truth
from tools import load_pickle, print_intersect_stats, get_config
from cluster_validator import ClusterValidator
import ga_driver
from init_logger import init_logger
import argparse
from sklearn.metrics import pairwise_distances_chunked


def get_score_from_cosine_distance(cosine_dist):
    """Linearly map a cosine distance to a score (distance 0 -> 1, distance 2 -> 0).

    Works element-wise on numpy arrays as well as on plain scalars.
    """
    half_dist = cosine_dist * 0.5
    return 1 - half_dist


def get_topk_acc(labels_q, labels_db, dists, topk):
    """Return the fraction of queries whose label occurs among their ``topk`` nearest neighbours.

    ``dists[i, j]`` is the distance from query ``i`` to database item ``j``.
    Raises ZeroDivisionError when ``labels_q`` is empty.
    """
    hit_flags = get_topk_hits(labels_q, labels_db, dists, topk)
    n_queries = len(labels_q)
    return sum(hit_flags) / n_queries


def get_topk_hits(labels_q, labels_db, dists, topk):
    """Return, per query, whether its label appears among its ``topk`` closest database items.

    Args:
        labels_q: sequence of query labels, one per row of ``dists``.
        labels_db: sequence of database labels, one per column of ``dists``.
        dists: 2-D array-like of distances, ``dists[i, j]`` = query i to db item j.
            Entries set to ``np.inf`` (e.g. self-matches) sort last.
        topk: number of nearest neighbours to inspect per query.

    Returns:
        1-D boolean numpy array of length ``len(labels_q)``.
    """
    labels_q = np.asarray(labels_q)
    labels_db = np.asarray(labels_db)
    # Column indices of the ``topk`` smallest distances in each row.
    nearest = np.argsort(np.asarray(dists), axis=1)[:, :topk]
    top_labels = labels_db[nearest]  # shape (n_queries, topk)
    # Compare every neighbour label against its own query's label.
    return np.any(top_labels == labels_q[:, None], axis=1)


def get_top_ks(q_pids, distmat, ks=(1, 3, 5, 10)):
    """Calculate top-k accuracies for a query-vs-itself distance matrix.

    Args:
        q_pids: labels of the queries; the same set is used as the database.
        distmat: square distance matrix, ``distmat[i, j]`` = query i to item j.
        ks: iterable of k values to evaluate (an immutable tuple by default so
            the default cannot be mutated across calls).

    Returns:
        List of ``(k, accuracy)`` tuples in the order of ``ks``.
    """
    return [(k, get_topk_acc(q_pids, q_pids, distmat, k)) for k in ks]


def calculate_distances(embeddings, ids, uuids, reduce_func):
    """Compute cosine distances of ``embeddings`` against themselves, chunk by chunk.

    Each row-chunk is post-processed by ``reduce_func(chunk, start)`` and the
    processed chunks are stacked back into one matrix. ``ids`` is only used in
    the progress message; ``uuids`` is accepted for interface compatibility
    but not read.
    """
    print(f"Calculating distances for {len(embeddings)} embeddings and {len(ids)} IDs...")

    chunk_iter = pairwise_distances_chunked(
        embeddings, metric='cosine', reduce_func=reduce_func, n_jobs=-1
    )
    processed_chunks = [chunk for chunk in chunk_iter]
    stacked = np.concatenate(processed_chunks, axis=0)
    return stacked


def prepare_reduce_func():
    """Build the ``reduce_func`` passed to ``pairwise_distances_chunked``.

    The returned callable rescales each chunk of cosine distances as
    ``1 - get_score_from_cosine_distance(d)``. It also sets every sample's
    distance to itself to ``np.inf``, so a sample is never retrieved as its
    own neighbour.
    """
    def reduce_func(distmat, start):
        """Transform one row-chunk whose first row is global row ``start``."""
        distmat = 1 - get_score_from_cosine_distance(distmat)
        # Chunk row i is global row start + i, so its self-distance lives in
        # column start + i. np.fill_diagonal would mask column i instead,
        # which is only correct for the first chunk (start == 0).
        rows = np.arange(distmat.shape[0])
        distmat[rows, start + rows] = np.inf
        return distmat
    return reduce_func


def get_stats(df, filter_key, embeddings):
    """Compute top-k retrieval accuracy of ``embeddings`` against the labels in ``df``.

    Args:
        df: annotations DataFrame with ``uuid_x`` and ``filter_key`` columns.
            Row ``i`` must describe ``embeddings[i]``. The pipeline cell builds
            the embeddings list by iterating ``df['uuid_x']`` in order.
        filter_key: column holding the identity label used as ground truth.
        embeddings: sequence of embedding vectors, one per row of ``df``.

    Returns:
        List of ``(k, accuracy)`` tuples for k in 1, 3, 5, 10.

    Raises:
        ValueError: if ``df`` and ``embeddings`` have different lengths.
    """
    if len(df) != len(embeddings):
        raise ValueError(
            f"df has {len(df)} rows but {len(embeddings)} embeddings were given"
        )

    ids = list(df.index)
    uuids = dict(zip(ids, df['uuid_x']))

    reduce_func = prepare_reduce_func()

    # Calculate distance matrix; row/column i corresponds to row i of df.
    distmat = calculate_distances(embeddings, ids, uuids, reduce_func)

    # Read labels positionally (O(n)) so they stay aligned with distmat rows,
    # even when uuid_x values or index labels are duplicated.
    labels = df[filter_key].tolist()
    return get_top_ks(labels, distmat, ks=[1, 3, 5, 10])


In [2]:
# NOTE(review): `name_keys` is only defined by the pipeline cell below. This
# cell was executed (In [2]) before that cell ran, so it raised NameError, and
# it will also fail under Restart & Run All. Move it after the pipeline cell or delete it.
name_keys

NameError: name 'name_keys' is not defined

In [9]:

# NOTE(review): each log line in the output below appears 4 times, which suggests
# init_logger() attaches another handler every time this cell is re-run —
# consider making it idempotent.
init_logger()
CONFIG_PATH = "configs/config_zebra.yaml"
config = get_config(CONFIG_PATH)

# Main pipeline: load embeddings, preprocess annotations, and report top-k
# retrieval accuracy of the embeddings against the identity labels.
data_params = config['data']

# Load embeddings and UUIDs; embeddings[i] belongs to the annotation uuids[i].
embeddings, uuids = load_pickle(data_params['embedding_file'])

# Preprocess data. The identity label columns (e.g. ['individual_uuid'] or
# ['name']) are joined into a single label key.
name_keys = data_params['name_keys']
filter_key = '__'.join(name_keys)

df = preprocess_data(
    data_params['annotation_file'],
    name_keys=name_keys,
    convert_names_to_ids=True,
    viewpoint_list=data_params['viewpoint_list'],
    n_filter_min=data_params['n_filter_min'],
    n_filter_max=data_params['n_filter_max'],
    images_dir=data_params['images_dir'],
    embedding_uuids=uuids
)

print_intersect_stats(df, individual_key=filter_key)

# Filter dataframe by UUIDs
filtered_df = df[df['uuid_x'].isin(uuids)]
print('     ', len(filtered_df), 'annotations remain after filtering by the provided embeddings')

# Map each UUID to the position of its first occurrence in a single pass
# (same result as uuids.index, without a linear scan per annotation).
uuid_to_pos = {}
for pos, uuid in enumerate(uuids):
    uuid_to_pos.setdefault(uuid, pos)
filtered_embeddings = [embeddings[uuid_to_pos[uuid]] for uuid in filtered_df['uuid_x']]
print('     ', len(filtered_embeddings), 'embeddings remain after filtering by the provided annotations')

# Compute statistics
topk_results = get_stats(filtered_df, filter_key, filtered_embeddings)

print(f"Statistics: {', '.join([f'top-{k}: {100*v:.2f}%' for (k, v) in topk_results])}")


INFO   2024-09-19 12:24:43,519 [          tools.py: 14] Loading config from path: configs/config_zebra.yaml
INFO   2024-09-19 12:24:43,519 [          tools.py: 14] Loading config from path: configs/config_zebra.yaml
INFO   2024-09-19 12:24:43,519 [          tools.py: 14] Loading config from path: configs/config_zebra.yaml
INFO   2024-09-19 12:24:43,519 [          tools.py: 14] Loading config from path: configs/config_zebra.yaml


INIT
Merging on image uuid
** Loaded /ekaterina/work/data/zebra/annotations/zebra.json **
      Found 3856 annotations
      3846 annotations remain after filtering by given uuids
      3846 annotations remain after filtering by viewpoint list ['right', 'left']
      3846 annotations remain after filtering by min 2 per name__viewpoint


INFO   2024-09-19 12:24:44,670 [          tools.py:177] ** Dataset statistcs **
INFO   2024-09-19 12:24:44,670 [          tools.py:177] ** Dataset statistcs **
INFO   2024-09-19 12:24:44,670 [          tools.py:177] ** Dataset statistcs **
INFO   2024-09-19 12:24:44,670 [          tools.py:177] ** Dataset statistcs **
INFO   2024-09-19 12:24:44,674 [          tools.py:178]  - Counts: 
INFO   2024-09-19 12:24:44,674 [          tools.py:178]  - Counts: 
INFO   2024-09-19 12:24:44,674 [          tools.py:178]  - Counts: 
INFO   2024-09-19 12:24:44,674 [          tools.py:178]  - Counts: 
INFO   2024-09-19 12:24:44,681 [          tools.py:182]  ---- number of individuals: 413
INFO   2024-09-19 12:24:44,681 [          tools.py:182]  ---- number of individuals: 413
INFO   2024-09-19 12:24:44,681 [          tools.py:182]  ---- number of individuals: 413
INFO   2024-09-19 12:24:44,681 [          tools.py:182]  ---- number of individuals: 413
INFO   2024-09-19 12:24:44,685 [          tools.py:1

      3846 annotations remain after filtering by max 100 per name__viewpoint
      3846 annotations remain after filtering by the provided embeddings
      3846 embeddings remain after filtering by the provided annotations
Calculating distances for 3846 embeddings and 3846 IDs...
Statistics: top-1: 78.94%, top-3: 89.86%, top-5: 91.89%, top-10: 92.80%


In [3]:
# NOTE(review): depends on `df` created by the pipeline cell above. The saved
# NameError comes from running this cell (In [3]) before that cell had run.
df.to_csv("zebra_df.csv")

NameError: name 'df' is not defined