In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import argparse
import h5py
import numpy as np
import pandas as pd
from sklearn import preprocessing
import os
import time
from pathlib import Path
from urllib.request import urlretrieve
import logging
from li.Baseline import Baseline
from li.LearnedIndex import LearnedIndex
from li.utils import save_as_pickle
from li.model import data_X_to_torch


In [3]:
np.random.seed(2023)

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)-5.5s][%(name)-.20s] %(message)s'
)
LOG = logging.getLogger(__name__)


In [4]:
kind='pca32v2'
size='100K'
key='pca32'

In [5]:
data = np.array(h5py.File(os.path.join("data", kind, size, "dataset.h5"), "r")[key])
queries = np.array(h5py.File(os.path.join("data", kind, size, "query.h5"), "r")[key])

In [6]:
kind_search = 'clip768v2'
key_search = 'emb'
data_search = np.array(
    h5py.File(os.path.join("data", kind_search, size, "dataset.h5"), "r")[key_search]
)

In [7]:
queries_search = np.array(
    h5py.File(os.path.join("data", kind_search, size, "query.h5"), "r")[key_search]
)

In [8]:
import torch
from torch import nn
import torch.nn.functional as nnf
import numpy as np
from li.Logger import Logger
from typing import Tuple
import torch.utils.data

torch.manual_seed(2023)
np.random.seed(2023)

class Model(nn.Module):
    def __init__(self, input_dim=768, output_dim=1000, model_type=None):
        super().__init__()
        if model_type == 'MLP':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, output_dim)
            )
        if model_type == 'MLP-2':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, output_dim)
            )
        if model_type == 'MLP-3':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, output_dim)
            )
        if model_type == 'MLP-4':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 512),
                torch.nn.ReLU(),
                torch.nn.Linear(512, output_dim)
            )
        if model_type == 'MLP-5':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, output_dim)
            )
        if model_type == 'MLP-6':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 32),
                torch.nn.ReLU(),
                torch.nn.Linear(32, output_dim)
            )
        if model_type == 'MLP-7':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 16),
                torch.nn.ReLU(),
                torch.nn.Linear(16, output_dim)
            )
        if model_type == 'MLP-8':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 8),
                torch.nn.ReLU(),
                torch.nn.Linear(8, output_dim)
            )
        if model_type == 'MLP-9':
            self.layers = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 8),
                torch.nn.ReLU(),
                torch.nn.Linear(input_dim, 16),
                torch.nn.ReLU(),
                torch.nn.Linear(16, output_dim)
            )
        self.n_output_neurons = output_dim

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        outputs = self.layers(x)
        return outputs


def data_X_to_torch(data) -> torch.FloatTensor:
    """ Creates torch training data."""
    data_X = torch.from_numpy(np.array(data).astype(np.float32))
    return data_X


def data_to_torch(data, labels) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    """ Creates torch training data and labels."""
    data_X = data_X_to_torch(data)
    data_y = torch.as_tensor(torch.from_numpy(labels), dtype=torch.long)
    return data_X, data_y


def get_device() -> torch.device:
    """ Gets the `device` to be used by torch.
    This arugment is needed to operate with the PyTorch model instance.

    Returns
    ------
    torch.device
        Device
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    torch.backends.cudnn.benchmark = True
    return device


class NeuralNetwork(Logger):
    """ The neural network class corresponding to every inner node.

    Parameters
    ----------
    input_dim : int
        The input dimension.
    output_dim : int
        The output dimension.
    loss : torch.nn, optional
        The loss function, the default is torch.nn.CrossEntropyLoss.
    lr : float, optional
        The learning rate, the default is 0.001.
    model_type : str, optional
        The model type, the default is 'MLP'.
    class_weight : torch.FloatTensor, optional
        The class weights, the default is None.
    """
    def __init__(
        self,
        input_dim,
        output_dim,
        loss=torch.nn.CrossEntropyLoss,
        lr=0.1,
        model_type='MLP',
        class_weight=None
    ):
        self.device = get_device()
        self.model = Model(input_dim, output_dim, model_type=model_type).to(self.device)
        if not isinstance(class_weight, type(None)):
            self.loss = loss(weight=class_weight.to(self.device))
        else:
            self.loss = loss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def train(
        self,
        data_X: torch.FloatTensor,
        data_y: torch.LongTensor,
        epochs=500,
        logger=None
    ):
        step = epochs // 10
        losses = []
        if logger:
            logger.info(f'Epochs: {epochs}, step: {step}')
        for ep in range(epochs):
            pred_y = self.model(data_X.to(self.device))
            curr_loss = self.loss(pred_y, data_y.to(self.device))
            if ep % step == 0 and ep != 0:
                if logger:
                    logger.info(f'Epoch {ep} | Loss {curr_loss.item()}')
            losses.append(curr_loss.item())

            self.model.zero_grad()
            curr_loss.backward()

            self.optimizer.step()
        return losses

    def train_batch(
        self,
        dataset,
        epochs=5,
        logger=None
    ):
        step = epochs // 10
        step = step if step > 0 else 1
        losses = []
        if logger:
            logger.info(f'Epochs: {epochs}, step: {step}')
        for ep in range(epochs):
            for data_X, data_y in iter(dataset):
                pred_y = self.model(data_X.to(self.device))
                curr_loss = self.loss(pred_y, data_y.to(self.device))

            if ep % step == 0 and ep != 0:
                if logger:
                    logger.info(f'Epoch {ep} | Loss {curr_loss.item():.5f}')
            losses.append(curr_loss.item())

            self.model.zero_grad()
            curr_loss.backward()

            self.optimizer.step()
        return losses

    def predict(self, data_X: torch.FloatTensor):
        """ Collects predictions for multiple data points (used in structure building)."""
        self.model = self.model.to(self.device)
        self.model.eval()

        all_outputs = torch.tensor([], device=self.device)
        with torch.no_grad():
            outputs = self.model(data_X.to(self.device))
            all_outputs = torch.cat((all_outputs, outputs), 0)

        _, y_pred = torch.max(all_outputs, 1)
        return y_pred.cpu().numpy()

    def predict_proba(self, data_X: torch.FloatTensor):
        """ Collects predictions for a single data point (used in query predictions)."""
        self.model = self.model.to(self.device)
        self.model.eval()

        with torch.no_grad():
            outputs = self.model(data_X.to(self.device))

        if outputs.dim() == 1:
            dim = 0
        else:
            dim = 1
        prob = nnf.softmax(outputs, dim=dim)
        probs, classes = prob.topk(prob.shape[1])

        return probs.cpu().numpy(), classes.cpu().numpy()


class LIDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_x, dataset_y):
        self.dataset_x, self.dataset_y = data_to_torch(dataset_x, dataset_y)

    def __len__(self):
        return self.dataset_x.shape[0]

    def __getitem__(self, idx):
        return self.dataset_x[idx-1], self.dataset_y[idx-1]


In [9]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import time

def pairwise_cosine_threshold(x, y, threshold, cat_idxs, k=10):
    s = time.time()
    result = 1-cosine_similarity(x, y)
    t_pure_seq_search = time.time() - s
    # create an array of consisten shapes
    #print(result.shape)
    #print(threshold.shape, threshold[cat_idxs].shape)
    thresh_consistent = np.repeat(threshold[cat_idxs, np.newaxis], result.shape[1], 1)
    relevant_dists = np.where(result < thresh_consistent)
    # filter the relevant object ids
    try:
        relevant_object_ids = np.unique(relevant_dists[1])
        max_idx = relevant_object_ids.shape[0]
        if max_idx == 0:
            return None, t_pure_seq_search
    except ValueError:
        # There is no distance below the threshold, we can return
        return None, t_pure_seq_search
    max_idx = max_idx if max_idx > k else k
    # output array filled with some large value
    output_arr = np.full(shape=(result.shape[0], max_idx), fill_value=10_000, dtype=np.float)
    #index_df = pd.DataFrame(relevant_dists[0])
    #index_df = pd.DataFrame(relevant_dists[0], relevant_dists[1]).reset_index()
    # create indexes to store the relevant distances
    #index_df['mapping'] = index_df.groupby('index').ngroup()
    mapping = dict(zip(relevant_object_ids, np.arange(relevant_object_ids.shape[0])))
    # tried also with np.vectorize, wasn't faster
    output_arr_2nd_dim = np.array([mapping[x] for x in relevant_dists[1]])
    to_be_added = result[relevant_dists[0], relevant_dists[1]]
    # populate the output array
    output_arr[relevant_dists[0], output_arr_2nd_dim] = to_be_added
    return output_arr, relevant_object_ids, t_pure_seq_search

In [10]:
import numpy as np
from li.Logger import Logger
from li.utils import pairwise_cosine
import time
import torch
import torch.utils.data
import faiss
from tqdm import tqdm
import numpy as np
torch.manual_seed(2023)
np.random.seed(2023)

class LearnedIndex(Logger):

    def __init__(self):
        self.pq = []
        self.model = None

    def search(
        self,
        data_navigation,
        queries_navigation,
        data_search,
        queries_search,
        pred_categories,
        n_buckets=1,
        k=10,
        use_threshold=False
    ):
        """ Search for k nearest neighbors for each query in queries.

        Parameters
        ----------
        queries : np.array
            Queries to search for.
        data : np.array
            Data to search in.
        n_buckets : int
            Number of most similar buckets to search in.
        k : int
            Number of nearest neighbors to search for.

        Returns
        -------
        dists : np.array
            Array of shape (queries.shape[0], k) with distances to nearest neighbors for each query.
        anns : np.array
            Array of shape (queries.shape[0], k) with nearest neighbors for each query.
        time : float
            Time it took to search.
        """
        assert self.model is not None, 'Model is not trained, call `build` first.'
        s = time.time()
        _, pred_proba_categories = self.model.predict_proba(
            data_X_to_torch(queries_navigation)
        )
        t_inference = time.time() - s
        anns_final = None
        dists_final = None
        # sorts the predictions of a bucket for each query, ordered by lowest probability
        data_navigation['category'] = pred_categories

        # iterates over the predicted buckets starting from the most similar (index -1)
        t_all_buckets = 0
        t_all_pairwise = 0
        t_all_sort = 0
        t_all_pure_pairwise = 0
        t_comp_threshold = 0
        for bucket in range(n_buckets):
            if bucket != 0 and use_threshold:
                s_ = time.time()
                threshold_dist = dists_final.max(axis=1)
                t_comp_threshold += time.time() - s_
            else:
                threshold_dist = None
            dists, anns, t_all, t_pairwise, t_pure_pairwise, t_sort = self.search_single(
                data_navigation,
                data_search,
                queries_search,
                pred_proba_categories[:, bucket],
                threshold_dist=threshold_dist
            )
            t_all_buckets += t_all
            t_all_pairwise += t_pairwise
            t_all_pure_pairwise += t_pure_pairwise
            t_all_sort += t_sort
            if anns_final is None:
                anns_final = anns
                dists_final = dists
            else:
                # stacks the results from the previous sorted anns and dists
                # *_final arrays now have shape (queries.shape[0], k*2)
                anns_final = np.hstack((anns_final, anns))
                dists_final = np.hstack((dists_final, dists))
                # gets the sorted indices of the stacked dists
                idx_sorted = dists_final.argsort(kind='stable', axis=1)[:, :k]
                # indexes the final arrays with the sorted indices
                # *_final arrays now have shape (queries.shape[0], k)
                idx = np.ogrid[tuple(map(slice, dists_final.shape))]
                idx[1] = idx_sorted
                dists_final = dists_final[tuple(idx)]
                anns_final = anns_final[tuple(idx)]

                assert anns_final.shape == dists_final.shape == (queries_search.shape[0], k)

        self.logger.info(f't_comp_threshold: {t_comp_threshold}')
        return dists_final, anns_final, time.time() - s, t_inference, t_all_buckets, t_all_pairwise, t_all_pure_pairwise, t_all_sort

    def search_single(
        self,
        data_navigation,
        data_search,
        queries_search,
        pred_categories,
        k=10,
        threshold_dist=None
    ):
        """ Search for k nearest neighbors for each query in queries.

        Parameters
        ----------
        queries : np.array
            Queries to search for.
        data : np.array
            Data to search in.
        k : int
            Number of nearest neighbors to search for.

        Returns
        -------
        anns : np.array
            Array of shape (queries.shape[0], k) with nearest neighbors for each query.
        final_dists_k : np.array
            Array of shape (queries.shape[0], k) with distances to nearest neighbors for each query.
        time : float
            Time it took to search.
        """
        s_all = time.time()
        nns = np.zeros((queries_search.shape[0], k), dtype=np.uint32)
        dists = np.zeros((queries_search.shape[0], k), dtype=np.float32)

        if 'category' in data_search.columns:
            data_search = data_search.drop('category', axis=1, errors='ignore')

        t_pairwise = 0
        t_pure_pairwise = 0
        t_sort = 0
        for cat, g in tqdm(data_navigation.groupby('category')):
            cat_idxs = np.where(pred_categories == cat)[0]
            bucket_obj_indexes = g.index
            if bucket_obj_indexes.shape[0] != 0 and cat_idxs.shape[0] != 0:
                s = time.time()
                # TODO: Add filter, filter will be different for every query
                # OR pass nns, dists from previous buckets
                if threshold_dist is not None:
                    seq_search_dists = pairwise_cosine_threshold(
                        queries_search[cat_idxs],
                        data_search.loc[bucket_obj_indexes],
                        threshold_dist,
                        cat_idxs,
                        k
                    )
                    if seq_search_dists[0] is None:
                        t_pure_pairwise += seq_search_dists[1]
                        # There is no distance below the threshold, we can continue
                        continue
                    else:
                        # seq_search_dists[1] contains the indexes of the relevant objects
                        bucket_obj_indexes = bucket_obj_indexes[seq_search_dists[1]]
                        t_pure_pairwise += seq_search_dists[2]
                        seq_search_dists = seq_search_dists[0]
                else:
                    s_ = time.time()
                    seq_search_dists = pairwise_cosine(
                        queries_search[cat_idxs],
                        data_search.loc[bucket_obj_indexes]
                    )
                    t_pure_pairwise += time.time() - s_
                t_pairwise += time.time() - s
                s = time.time()
                ann_relative = seq_search_dists.argsort(kind='quicksort')[
                    :, :k if k < seq_search_dists.shape[1] else seq_search_dists.shape[1]
                ]
                t_sort += time.time() - s
                if bucket_obj_indexes.shape[0] < k:
                    # pad to `k` if needed
                    pad_needed = (k - bucket_obj_indexes.shape[0]) // 2 + 1
                    bucket_obj_indexes = np.pad(np.array(bucket_obj_indexes), pad_needed, 'edge')[:k]
                    ann_relative = np.pad(ann_relative[0], pad_needed, 'edge')[:k].reshape(1, -1)
                    seq_search_dists = np.pad(seq_search_dists[0], pad_needed, 'edge')[:k].reshape(1, -1)
                    _, i = np.unique(seq_search_dists, return_index=True)
                    duplicates_i = np.setdiff1d(np.arange(k), i)
                    # assign a large number such that the duplicated value gets replaced
                    seq_search_dists[0][duplicates_i] = 10_000

                nns[cat_idxs] = np.array(bucket_obj_indexes)[ann_relative]
                dists[cat_idxs] = np.take_along_axis(seq_search_dists, ann_relative, axis=1)
        t_all = time.time() - s_all
        return dists, nns, t_all, t_pairwise, t_pure_pairwise, t_sort


    def build(self, data, n_categories=100, epochs=100, lr=0.1, model_type='MLP'):
        """ Build the index.

        Parameters
        ----------
        data : np.array
            Data to build the index on.

        Returns
        -------
        time : float
            Time it took to build the index.
        """
        s = time.time()
        # ---- cluster the data into categories ---- #
        _, labels = self.cluster(data, n_categories)

        # ---- train a neural network ---- #
        dataset = LIDataset(data, labels)
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=256,
            sampler=torch.utils.data.SubsetRandomSampler(
                data.index.values.tolist()
            )
        )
        nn = NeuralNetwork(
            input_dim=data.shape[1],
            output_dim=n_categories,
            lr=lr,
            model_type=model_type
        )
        nn.train_batch(train_loader, epochs=epochs, logger=self.logger)
        # ---- collect predictions ---- #
        self.model = nn
        return nn.predict(data_X_to_torch(data)), time.time() - s

    def cluster(
        self,
        data,
        n_clusters,
        n_redo=10,
        spherical=True,
        int_centroids=True,

    ):
        if data.shape[0] < 2:
            return None, np.zeros_like(data.shape[0])

        if data.shape[0] < n_clusters:
            n_clusters = data.shape[0] // 5
            if n_clusters < 2:
                n_clusters = 2

        kmeans = faiss.Kmeans(
            d=np.array(data).shape[1],
            k=n_clusters,
            verbose=True,
            #nredo=n_redo,
            #spherical=spherical,
            #int_centroids=int_centroids,
            #update_index=False,
            seed=2023
        )
        X = np.array(data).astype(np.float32)
        kmeans.train(X)

        return kmeans, kmeans.index.search(X, 1)[1].T[0]


[2023-07-14 15:11:32,229][INFO ][faiss.loader] Loading faiss with AVX2 support.
[2023-07-14 15:11:32,232][INFO ][faiss.loader] Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'",)
[2023-07-14 15:11:32,234][INFO ][faiss.loader] Loading faiss.
[2023-07-14 15:11:32,265][INFO ][faiss.loader] Successfully loaded faiss.


In [11]:
data = pd.DataFrame(data)
data.index += 1

data_search = pd.DataFrame(data_search)
data_search.index += 1

In [29]:
li = LearnedIndex()
pred_categories, build_t = li.build(
    data,
    n_categories=200,
    epochs=200,
    lr=0.1
)

[2023-07-14 15:06:06,034][INFO ][__main__.LearnedInde] Epochs: 200, step: 20
[2023-07-14 15:06:29,323][INFO ][__main__.LearnedInde] Epoch 20 | Loss 1.98601
[2023-07-14 15:06:51,398][INFO ][__main__.LearnedInde] Epoch 40 | Loss 1.18380
[2023-07-14 15:07:14,368][INFO ][__main__.LearnedInde] Epoch 60 | Loss 1.05507
[2023-07-14 15:07:36,406][INFO ][__main__.LearnedInde] Epoch 80 | Loss 1.10324
[2023-07-14 15:07:58,435][INFO ][__main__.LearnedInde] Epoch 100 | Loss 1.13756
[2023-07-14 15:08:21,092][INFO ][__main__.LearnedInde] Epoch 120 | Loss 0.89168
[2023-07-14 15:08:43,176][INFO ][__main__.LearnedInde] Epoch 140 | Loss 1.07258
[2023-07-14 15:09:05,576][INFO ][__main__.LearnedInde] Epoch 160 | Loss 0.81745
[2023-07-14 15:09:27,546][INFO ][__main__.LearnedInde] Epoch 180 | Loss 0.70368


In [12]:
LOG.info(f'Loading GT')
gt_path = f'data/groundtruth-{size}.h5'
f3 = h5py.File(gt_path, 'r')
loaded_gt = f3['knns'][:, :]

[2023-07-14 15:11:36,772][INFO ][__main__] Loading GT


In [31]:
k=10
bucket=4

In [34]:
dists, nns, search_t, inference_t, search_single_t, seq_search_t, pure_seq_search_t, sort_t = li.search(
    data_navigation=data,
    queries_navigation=queries,
    data_search=data_search,
    queries_search=queries_search,
    pred_categories=pred_categories,
    n_buckets=bucket,
    k=k,
    use_threshold=True
)
LOG.info('Inference time: %s', inference_t)
LOG.info('Search time: %s', search_t)
LOG.info('Search single time: %s', search_single_t)
LOG.info('Sequential search time: %s', seq_search_t)
LOG.info('Pure sequential search time: %s', pure_seq_search_t)
LOG.info('Sort time: %s', sort_t)

100%|██████████| 200/200 [00:01<00:00, 135.27it/s]
100%|██████████| 200/200 [00:01<00:00, 151.76it/s]
100%|██████████| 200/200 [00:01<00:00, 166.24it/s]
100%|██████████| 200/200 [00:01<00:00, 166.79it/s]
[2023-07-14 15:10:20,767][INFO ][__main__.LearnedInde] t_comp_threshold: 0.0014171600341796875
[2023-07-14 15:10:20,770][INFO ][__main__] Inference time: 0.16141343116760254
[2023-07-14 15:10:20,773][INFO ][__main__] Search time: 5.416271448135376
[2023-07-14 15:10:20,776][INFO ][__main__] Search single time: 5.240442276000977
[2023-07-14 15:10:20,778][INFO ][__main__] Sequential search time: 4.479606866836548
[2023-07-14 15:10:20,781][INFO ][__main__] Pure sequential search time: 3.7803361415863037
[2023-07-14 15:10:20,783][INFO ][__main__] Sort time: 0.373150110244751


In [35]:
def get_recall(I, gt, k):
    assert k <= I.shape[1]
    assert len(I) == len(gt)

    n = len(I)
    recall = 0
    for i in range(n):
        recall += len(set(I[i, :k]) & set(gt[i, :k]))
    return recall / (n * k)


In [36]:
recall = get_recall(nns, loaded_gt, 10)


In [37]:
recall

0.70667

In [None]:
li = LearnedIndex()
pred_categories, build_t = li.build(
    data,
    n_categories=100,
    epochs=200,
    lr=0.1
)

[2023-07-14 15:11:44,421][INFO ][__main__.LearnedInde] Epochs: 200, step: 20
[2023-07-14 15:12:07,214][INFO ][__main__.LearnedInde] Epoch 20 | Loss 1.23865
[2023-07-14 15:12:28,505][INFO ][__main__.LearnedInde] Epoch 40 | Loss 0.85032


In [None]:
k=10
bucket=8

In [None]:
dists, nns, search_t, inference_t, search_single_t, seq_search_t, pure_seq_search_t, sort_t = li.search(
    data_navigation=data,
    queries_navigation=queries,
    data_search=data_search,
    queries_search=queries_search,
    pred_categories=pred_categories,
    n_buckets=bucket,
    k=k,
    use_threshold=True
)
LOG.info('Inference time: %s', inference_t)
LOG.info('Search time: %s', search_t)
LOG.info('Search single time: %s', search_single_t)
LOG.info('Sequential search time: %s', seq_search_t)
LOG.info('Pure sequential search time: %s', pure_seq_search_t)
LOG.info('Sort time: %s', sort_t)

In [None]:
recall = get_recall(nns, loaded_gt, 10)
recall