In [1]:
import argparse
import logging.config
import os
from collections import defaultdict

from dotenv import load_dotenv

from rescueclip.logging_config import LOGGING_CONFIG

logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger(__name__)
from pathlib import Path

import numpy as np
from typing import cast, Sequence, List, Any, Literal
import weaviate
from tqdm import tqdm
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.collections.classes.types import Properties, WeaviateProperties
from weaviate.collections.classes.internal import Object
from weaviate.util import generate_uuid5, get_vector

from rescueclip import cuhk
from rescueclip.cuhk import SetNumToImagesMap
from rescueclip.ml_model import (
    CollectionConfig,
    CUHK_Apple_Collection,
    CUHK_Google_Siglip_Base_Patch16_224_Collection,
    CUHK_Google_Siglip_SO400M_Patch14_384_Collection,
    CUHK_laion_CLIP_ViT_bigG_14_laion2B_39B_b160k_Collection,
    CUHK_MetaCLIP_ViT_bigG_14_quickgelu_224_Collection,
    CUHK_ViT_B_32_Collection,
)
from rescueclip.weaviate import WeaviateClientEnsureReady

from pprint import pprint

from embed_cuhk import Metadata, embed_cuhk_dataset

load_dotenv()

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
collection_config = CUHK_Apple_Collection
client = WeaviateClientEnsureReady().get_client()
collection = client.collections.get(collection_config.name)

2025-02-22 13:12:53,635 [INFO] rescueclip.weaviate: Weaviate is ready


In [3]:
QUERY_MAXIMUM_RESULTS = 200_000

number_of_objects: int = collection.aggregate.over_all(total_count=True).total_count # type: ignore
logger.info(f"Number of objects %s", number_of_objects)
assert (
    number_of_objects <= QUERY_MAXIMUM_RESULTS 
), "Ensure docker-compose.yml has QUERY_MAXIMUM_RESULTS to greater than 200_000 or the experiment's results may be inaccurate"

2025-02-22 13:12:53,672 [INFO] __main__: Number of objects 18596


In [5]:
# Train test split
# Remove one random image from each series
INPUT_FOLDER = Path(os.environ["CUHK_PEDES_DATASET"]) / "out"
STOPS_FILE = Path("/scratch3/gbiss/images/CUHK-PEDES-OFFICIAL/caption_all.json")
sets = cuhk.get_sets_new(INPUT_FOLDER, STOPS_FILE)
sets = cuhk.keep_sets_containing_n_images(sets, 4)

set_number_set_list_pairs = list(sets.items())
np.random.shuffle(set_number_set_list_pairs)

_in_sample_series = set_number_set_list_pairs[: len(set_number_set_list_pairs) // 2]
in_sample_series = {set_num: file_names for set_num, file_names in _in_sample_series}
_heldout_series = set_number_set_list_pairs[len(set_number_set_list_pairs) // 2 :]
heldout_series = {set_num: file_names for set_num, file_names in _heldout_series}

logger.info(f"Total series: {len(set_number_set_list_pairs)}")
logger.info(f"In-sample series: {len(in_sample_series)}")
logger.info(f"Held-out series: {len(heldout_series)}")


logger.info("Retrieving the entire database into memory")
results = collection.query.fetch_objects(
    limit=QUERY_MAXIMUM_RESULTS,
    include_vector=True,
    return_properties=True,
)
assert len(results.objects) == number_of_objects, "Expected the entire database to be retrieved"
X = np.array([obj.vector["embedding"] for obj in results.objects])
y_set_labels = np.array([obj.properties["set_number"] for obj in results.objects])
X_indices_in_insample_series = np.array(
    [
        i
        for i, image_metadata in enumerate(results.objects)
        if image_metadata.properties["set_number"] not in heldout_series
    ]
)
X_train = X[X_indices_in_insample_series]
y_train_set_labels = y_set_labels[X_indices_in_insample_series]

logger.info(f"X.shape: {X.shape}")
logger.info(f"y_set_labels.shape: {y_set_labels.shape}")
logger.info(f"X_train.shape: {X_train.shape}")
logger.info(f"y_train_set_labels.shape: {y_train_set_labels.shape}")

2025-02-22 13:13:29,823 [INFO] rescueclip.cuhk: After filtering, using 4649 sets and 18596 images
2025-02-22 13:13:29,826 [INFO] __main__: Total series: 4649
2025-02-22 13:13:29,826 [INFO] __main__: In-sample series: 2324
2025-02-22 13:13:29,827 [INFO] __main__: Held-out series: 2325
2025-02-22 13:13:29,827 [INFO] __main__: Retrieving the entire database into memory
2025-02-22 13:13:33,202 [INFO] __main__: X.shape: (18596, 1024)
2025-02-22 13:13:33,205 [INFO] __main__: y_set_labels.shape: (18596,)
2025-02-22 13:13:33,206 [INFO] __main__: X_train.shape: (9296, 1024)
2025-02-22 13:13:33,206 [INFO] __main__: y_train_set_labels.shape: (9296,)


In [6]:
# def filter_out_neighbors_in_heldout_set(neighbor_objects: Sequence[Object[WeaviateProperties, Any]], heldout_series: SetNumToImagesMap) -> Sequence[Object[WeaviateProperties, None]]:
#     result = []

#     for objectt in neighbor_objects:
#         set_num = objectt.properties["set_number"]
#         file_name = objectt.properties["file_name"]
#         if set_num in heldout_series:
#             if file_name in heldout_series[set_num]:
#                 result.append(objectt)

#     return result

In [None]:
from dataclasses import dataclass
from scipy.spatial.distance import cdist

@dataclass
class ConfusionMatrix:
    tp: int = 0
    tn: int = 0
    fp: int = 0
    fn: int = 0

    def update(self, *, tp: int = 0, tn: int = 0, fp: int = 0, fn: int = 0):
        self.tp += tp
        self.tn += tn
        self.fp += fp
        self.fn += fn

    def as_array(self):
        # Returns a 2x2 array: [[tn, fp], [fn, tp]]
        return [[self.tn, self.fp], [self.fn, self.tp]]
    
    def precision(self):
        if self.tp + self.fp == 0:
            return 0.0
        return self.tp / (self.tp + self.fp)

    def recall(self):
        if self.tp + self.fn == 0:
            return 0.0
        return self.tp / (self.tp + self.fn)

    def f1(self):
        prec = self.precision()
        rec = self.recall()
        if prec + rec == 0:
            return 0.0
        return 2 * (prec * rec) / (prec + rec)

    def __str__(self):
        return f"TP: {self.tp}, TN: {self.tn}, FP: {self.fp}, FN: {self.fn}"

def get_set_number_of_neighbors_within_t_ordered_by_t(vector: np.ndarray, t: float, X: np.ndarray, y_labels: np.ndarray, distance_metric: Literal['cosine']) -> np.ndarray:
    """
    Return the set numbers of the closest neighboring images within t, ordered by t.
    """
    assert X.shape[0] == len(y_labels), "Expected X.shape[0] == len(y_labels)"
    assert len(vector) == X.shape[1], "Expected len(vector) == X.shape[1]"

    if distance_metric == 'cosine':
        # Get the distances between the test image and all the images in X
        distances = cdist(X, [vector], metric='cosine')
        assert distances.shape == (X.shape[0], 1), f"Expected distances.shape == (X.shape[0], 1), got {distances.shape}"
        distances = distances.reshape(-1)
    else:
        raise NotImplementedError(f"distance_metric {distance_metric} not implemented")
    
    assert distances.shape == (X.shape[0], ), f"Expected distances.shape == (X.shape[0], ), expected {(X.shape[0], )}, got {distances.shape}"

    arr = distances

    # Get the indices of the closest images within t
    indices = np.where(distances <= t)[0]
    indices = indices[np.argsort(distances[indices])]

    sorted_distances = distances[indices]
    sorted_y_lables = y_labels[indices]
    
    return sorted_y_lables

def threshold_test(sample_series, X_train, y_train_set_labels, t, confusion_matrix: ConfusionMatrix, is_in_sample=True) -> None:
    for sample_serie in tqdm(sample_series.items(), total=len(sample_series), desc=f"t={t} | In-sample: {is_in_sample}"):
        set_num, images_fps = sample_serie
        for image_fp in images_fps:
            # Get the object for image_fp
            objects = collection.query.fetch_objects(
                filters=Filter.by_property("set_number").equal(set_num)
                & Filter.by_property("file_name").equal(image_fp),
                include_vector=True
            ).objects
            assert len(objects) == 1, "Expected 1 object"
            objectt = objects[0]
            
            # Get the neighbors within t
            neighboring_sets = get_set_number_of_neighbors_within_t_ordered_by_t(np.array(objectt.vector['embedding']), t, X_train, y_train_set_labels, 'cosine')
            assert len(neighboring_sets.shape) == 1, "Expected neighboring_sets to be a 1D array"

            if is_in_sample:
                assert len(neighboring_sets) >= 1, "Expected at least the test image itself, because it is in the DB"
                assert neighboring_sets[0] == set_num, "Expected the best distance set to be the test image itself"
                neighboring_sets = neighboring_sets[1:]

            if is_in_sample:
                if set_num in neighboring_sets:
                    confusion_matrix.tp += 1
                else:
                    confusion_matrix.fp += 1
            else:
                if len(neighboring_sets) == 0:
                    confusion_matrix.tn += 1
                else:
                    confusion_matrix.fn += 1

cm = ConfusionMatrix()
t = 0.1

threshold_test(in_sample_series, X_train, y_train_set_labels, t, cm, is_in_sample=True)
threshold_test(heldout_series, X_train, y_train_set_labels, t, cm, is_in_sample=False)


t=0.1 | In-sample: True: 100%|██████████| 2324/2324 [02:47<00:00, 13.91it/s]
t=0.1 | In-sample: False: 100%|██████████| 2325/2325 [02:46<00:00, 13.94it/s]


In [8]:
cm

ConfusionMatrix(tp=4788, tn=5527, fp=4508, fn=3773)