In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import h5py
import pandas as pd
import pickle
from tqdm import tqdm
from li.utils import pairwise_cosine
import time
import logging
import numpy as np
import os
from scipy import sparse


In [3]:
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)-5.5s][%(name)-.20s] %(message)s'
)
LOG = logging.getLogger(__name__)

def increase_max_recursion_limit():
    """ Increases the maximum recursion limit.
    Source: https://stackoverflow.com/a/16248113
    """
    import sys
    import resource
    resource.setrlimit(resource.RLIMIT_STACK, (2**29, -1))
    sys.setrecursionlimit(10**6)


In [19]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from li.Logger import Logger
from li.utils import pairwise_cosine
import time
from sklearn.metrics import accuracy_score
import torch
import torch.utils.data
#from li.model import NeuralNetwork, data_X_to_torch, LIDataset
import faiss

import torch
from torch import nn
import torch.nn.functional as nnf
from dlmi.utils import get_device, reverse_dict
import numpy as np
from dlmi.Logger import Logger
from typing import List, Tuple
import torch
import torch.utils.data


class Model(nn.Module):
    def __init__(self, input_dim=768, output_dim=1000, model_type=None):
        super().__init__()
        if model_type == 'MLP':
            self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, output_dim)
            )
        if model_type == 'Bigger':
            self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, output_dim)
            )
        self.n_output_neurons = output_dim

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        outputs = self.layers(x)
        return outputs


def get_device() -> torch.device:
    """ Gets the `device` to be used by torch.
    This arugment is needed to operate with the PyTorch model instance.

    Returns
    ------
    torch.device
        Device
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    torch.backends.cudnn.benchmark = True
    return device


def data_X_to_torch(data) -> torch.FloatTensor:
    """ Creates torch training data."""
    data_X = torch.from_numpy(np.array(data).astype(np.float32))
    return data_X


def data_to_torch(data, labels) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    """ Creates torch training data and labels."""
    data_X = data_X_to_torch(data)
    data_y = torch.as_tensor(torch.from_numpy(labels), dtype=torch.long)
    return data_X, data_y


class NeuralNetwork(Logger):
    """ The neural network class corresponding to every inner node.

    Parameters
    ----------
    input_dim : int
        The input dimension.
    output_dim : int
        The output dimension.
    loss : torch.nn, optional
        The loss function, the default is torch.nn.CrossEntropyLoss.
    lr : float, optional
        The learning rate, the default is 0.001.
    model_type : str, optional
        The model type, the default is 'MLP'.
    class_weight : torch.FloatTensor, optional
        The class weights, the default is None.
    """
    def __init__(
        self,
        input_dim,
        output_dim,
        loss=torch.nn.CrossEntropyLoss,
        lr=0.1,
        model_type='MLP',
        class_weight=None
    ):
        self.device = get_device()
        self.model = Model(input_dim, output_dim, model_type=model_type).to(self.device)
        if not isinstance(class_weight, type(None)):
            self.loss = loss(weight=class_weight.to(self.device))
        else:
            self.loss = loss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def train(
        self,
        data_X: torch.FloatTensor,
        data_y: torch.LongTensor,
        epochs=500,
        logger=None
    ):
        #logger.debug(f'Epochs: {epochs}')
        step = epochs // 10
        losses = []
        if logger:
            logger.info(f'Epochs: {epochs}, step: {step}')
        for ep in range(epochs):
            pred_y = self.model(data_X.to(self.device))
            curr_loss = self.loss(pred_y, data_y.to(self.device))
            if ep % step == 0 and ep != 0:
                if logger:
                    logger.info(f'Epoch {ep} | Loss {curr_loss.item()}')
            losses.append(curr_loss.item())

            self.model.zero_grad()
            curr_loss.backward()

            self.optimizer.step()
        return losses

    def train_batch(
        self,
        dataset,
        epochs=5,
        logger=None
    ):
        #logger.debug(f'Epochs: {epochs}')
        step = epochs // 10
        step = step if step > 0 else 1
        losses = []
        if logger:
            logger.info(f'Epochs: {epochs}, step: {step}')
        for ep in range(epochs):
            for data_X, data_y in iter(dataset):
                pred_y = self.model(data_X.to(self.device))
                curr_loss = self.loss(pred_y, data_y.to(self.device))

            if ep % step == 0 and ep != 0:
                if logger:
                    logger.info(f'Epoch {ep} | Loss {curr_loss.item():.5f}')
            losses.append(curr_loss.item())

            self.model.zero_grad()
            curr_loss.backward()

            self.optimizer.step()
        return losses

    def predict(self, data_X: torch.FloatTensor):
        """ Collects predictions for multiple data points (used in structure building)."""
        self.model = self.model.to(self.device)
        self.model.eval()

        all_outputs = torch.tensor([], device=self.device)
        with torch.no_grad():
            outputs = self.model(data_X.to(self.device))
            all_outputs = torch.cat((all_outputs, outputs), 0)

        _, y_pred = torch.max(all_outputs, 1)
        return y_pred.cpu().numpy()

    def predict_proba(self, data_X: torch.FloatTensor):
        """ Collects predictions for a single data point (used in query predictions)."""
        self.model = self.model.to(self.device)
        self.model.eval()

        with torch.no_grad():
            outputs = self.model(data_X.to(self.device))

        if outputs.dim() == 1:
            dim=0
        else:
            dim=1
        prob = nnf.softmax(outputs, dim=dim)

        #if prob.dim() == 1:
        #    probs, classes = prob.topk(prob.shape[0])
        #else:
        probs, classes = prob.topk(prob.shape[1])

        return probs.cpu().numpy(), classes.cpu().numpy()


class LIDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_x, dataset_y):
        self.dataset_x, self.dataset_y = data_to_torch(dataset_x, dataset_y)

    def __len__(self):
        return self.dataset_x.shape[0]
    
    def __getitem__(self, idx):
        return self.dataset_x[idx-1], self.dataset_y[idx-1]




class LearnedIndex(Logger):

    def __init__(self):
        self.pq = []
        self.model = None

    def search(
        self,
        data_navigation,
        queries_navigation,
        data_search,
        queries_search,
        pred_categories,
        n_buckets=1,
        k=10
    ):
        """ Search for k nearest neighbors for each query in queries.

        Parameters
        ----------
        queries : np.array
            Queries to search for.
        data : np.array
            Data to search in.
        n_buckets : int
            Number of most similar buckets to search in.
        k : int
            Number of nearest neighbors to search for.

        Returns
        -------
        dists : np.array
            Array of shape (queries.shape[0], k) with distances to nearest neighbors for each query.
        anns : np.array
            Array of shape (queries.shape[0], k) with nearest neighbors for each query.
        time : float
            Time it took to search.
        """
        assert self.model is not None, 'Model is not trained, call `build` first.'
        s = time.time()
        _, pred_proba_categories = self.model.predict_proba(
            data_X_to_torch(queries_navigation)
        )
        anns_final = None
        dists_final = None
        # sorts the predictions of a bucket for each query, ordered by lowest probability
        data_navigation['category'] = pred_categories

        # iterates over the predicted buckets starting from the most similar (index -1)
        for bucket in range(n_buckets):
            dists, anns = self.search_single(
                data_navigation,
                data_search,
                queries_search,
                pred_proba_categories[:, bucket]
            )
            if anns_final is None:
                anns_final = anns
                dists_final = dists
            else:
                # stacks the results from the previous sorted anns and dists
                # *_final arrays now have shape (queries.shape[0], k*2)
                anns_final = np.hstack((anns_final, anns))
                dists_final = np.hstack((dists_final, dists))
                # gets the sorted indices of the stacked dists
                idx_sorted = dists_final.argsort(kind='stable', axis=1)[:, :k]
                # indexes the final arrays with the sorted indices
                # *_final arrays now have shape (queries.shape[0], k)
                idx = np.ogrid[tuple(map(slice, dists_final.shape))]
                idx[1] = idx_sorted
                dists_final = dists_final[tuple(idx)]
                anns_final = anns_final[tuple(idx)]

                assert anns_final.shape == dists_final.shape == (queries_search.shape[0], k)

        return dists_final, anns_final, time.time() - s

    def search_single(
        self,
        data_navigation,
        data_search,
        queries_search,
        pred_categories,
        k=10
    ):
        """ Search for k nearest neighbors for each query in queries.

        Parameters
        ----------
        queries : np.array
            Queries to search for.
        data : np.array
            Data to search in.
        k : int
            Number of nearest neighbors to search for.

        Returns
        -------
        anns : np.array
            Array of shape (queries.shape[0], k) with nearest neighbors for each query.
        final_dists_k : np.array
            Array of shape (queries.shape[0], k) with distances to nearest neighbors for each query.
        time : float
            Time it took to search.
        """
        nns = np.zeros((queries_search.shape[0], k), dtype=np.uint32)
        dists = np.zeros((queries_search.shape[0], k), dtype=np.float32)

        for cat in np.unique(pred_categories):
            cat_idxs = np.where(pred_categories == cat)[0]
            bucket_obj_indexes = data_navigation.query('category == @cat').index
            #if bucket_obj_indexes.shape[0] != 0:
            seq_search_dists = pairwise_cosine(
                queries_search[cat_idxs], data_search.loc[bucket_obj_indexes]
            )
            ann_relative = seq_search_dists.argsort()[:, :k]
            nns[cat_idxs] = np.array(bucket_obj_indexes)[ann_relative]
            dists[cat_idxs] = np.take_along_axis(seq_search_dists, ann_relative, axis=1)

        return dists, nns

    def build(self, data, n_categories=100, epochs=100, lr=0.1):
        """ Build the index.

        Parameters
        ----------
        data : np.array
            Data to build the index on.

        Returns
        -------
        time : float
            Time it took to build the index.
        """
        s = time.time()
        # ---- cluster the data into categories ---- #
        _, labels = self.cluster(data, n_categories)

        # ---- train a neural network ---- #
        dataset = LIDataset(data, labels)
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=256,
            sampler=torch.utils.data.SubsetRandomSampler(
                data.index.values.tolist()
            )
        )
        nn = NeuralNetwork(
            input_dim=data.shape[1],
            output_dim=n_categories,
            lr=lr,
            model_type='MLP'
        )
        nn.train_batch(train_loader, epochs=epochs, logger=self.logger)
        # ---- collect predictions ---- #
        self.model = nn
        return nn.predict(data_X_to_torch(data)), time.time() - s

    def cluster(self, data, n_clusters):
        if data.shape[0] < 2:
            return None, np.zeros_like(data.shape[0])

        if data.shape[0] < n_clusters:
            n_clusters = data.shape[0] // 5
            if n_clusters < 2:
                n_clusters = 2

        kmeans = faiss.Kmeans(d=np.array(data).shape[1], k=n_clusters)
        X = np.array(data).astype(np.float32)
        kmeans.train(X)

        return kmeans, kmeans.index.search(X, 1)[1].T[0]

In [4]:
size = '100K'

LOG.info(f'Loading pca32 data')
data_path = f'data/pca32v2/{size}/dataset.h5'
f = h5py.File(data_path, 'r')
loaded_data = f['pca32'][:, :]
data = pd.DataFrame(loaded_data)
data.index += 1

LOG.info(f'Loading queries')
base_path = f'data/pca32v2/{size}/'
queries_path = f'{base_path}/query.h5'
f2 = h5py.File(queries_path, 'r')
#loaded_queries = f2['emb'][:, :]
loaded_queries = f2['pca32'][:, :]

base_path = f'data/clip768v2/{size}/'
queries_path = f'{base_path}/query.h5'
f2 = h5py.File(queries_path, 'r')
#loaded_queries = f2['emb'][:, :]
loaded_queries_seq = f2['emb'][:, :]

LOG.info(f'Loading clip data')
data_path = f'data/clip768v2/{size}/dataset.h5'
f = h5py.File(data_path, 'r')
loaded_clip_data = f['emb'][:, :]
loaded_clip_data = pd.DataFrame(loaded_clip_data)
loaded_clip_data.index += 1

LOG.info(f'Loading GT')
gt_path = f'data/groundtruth-{size}.h5'
f3 = h5py.File(gt_path, 'r')
loaded_gt = f3['knns'][:, :]


[2023-07-06 17:01:43,945][INFO ][__main__] Loading pca32 data
[2023-07-06 17:01:44,164][INFO ][__main__] Loading queries
[2023-07-06 17:01:44,626][INFO ][__main__] Loading clip data
[2023-07-06 17:01:46,290][INFO ][__main__] Loading GT


In [13]:
n_categories=100
epochs=100
lr=0.1

In [10]:
import os
os.path.abspath("")

'/auto/brno12-cerit/nfs4/home/tslaninakova/sisap-challenge-clean'

In [20]:
li = LearnedIndex()
# ---- build the index ---- #
pred_categories, build_t = li.build(
    data,
    n_categories=n_categories,
    epochs=epochs,
    lr=lr
)
LOG.info(f'Pure build time: {build_t}')

[2023-07-06 17:06:45,386][INFO ][__main__.LearnedInde] Epochs: 100, step: 10
[2023-07-06 17:06:58,380][INFO ][__main__.LearnedInde] Epoch 10 | Loss 2.12920
[2023-07-06 17:07:10,147][INFO ][__main__.LearnedInde] Epoch 20 | Loss 1.11942
[2023-07-06 17:07:21,924][INFO ][__main__.LearnedInde] Epoch 30 | Loss 1.08467
[2023-07-06 17:07:33,673][INFO ][__main__.LearnedInde] Epoch 40 | Loss 0.79363
[2023-07-06 17:07:45,422][INFO ][__main__.LearnedInde] Epoch 50 | Loss 0.84400
[2023-07-06 17:07:57,226][INFO ][__main__.LearnedInde] Epoch 60 | Loss 0.62600
[2023-07-06 17:08:08,967][INFO ][__main__.LearnedInde] Epoch 70 | Loss 0.84457
[2023-07-06 17:08:20,714][INFO ][__main__.LearnedInde] Epoch 80 | Loss 0.74277
[2023-07-06 17:08:32,440][INFO ][__main__.LearnedInde] Epoch 90 | Loss 0.72751
[2023-07-06 17:08:43,202][INFO ][__main__] Pure build time: 118.17009544372559


In [31]:
probs, pred_proba_categories = li.model.predict_proba(
    data_X_to_torch(loaded_queries)
)
anns_final = None
dists_final = None
# sorts the predictions of a bucket for each query, ordered by lowest probability
data['category'] = pred_categories

In [24]:
pred_proba_categories.shape

(10000, 100)

In [25]:
dists, anns = li.search_single(
    data,
    loaded_clip_data,
    loaded_queries_seq,
    pred_proba_categories[:, 0]
)

[2023-07-06 17:09:51,809][INFO ][numexpr.utils] Note: detected 128 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2023-07-06 17:09:51,810][INFO ][numexpr.utils] Note: NumExpr detected 128 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.


In [26]:
anns.shape

(10000, 10)

In [34]:
pred_proba_categories[0]

array([47, 56, 51, 60, 53, 35, 55, 41, 32, 89, 71, 80, 70, 85, 68, 49, 84,
       86, 30, 23, 45, 83, 82, 92, 87,  5, 69, 43, 38, 36, 99, 29, 27, 75,
       26, 34, 94,  2, 95, 25, 48, 97, 21, 46, 14, 31,  8, 20, 73, 24, 78,
       57,  0, 44, 59,  6, 88, 62, 19, 81, 98, 22, 93, 58, 77, 96, 40, 33,
        3, 39,  9, 67, 74,  1, 11, 15, 18, 52, 42, 28, 37, 16, 50, 91,  4,
       72, 63, 66, 76, 61, 54, 90, 12, 64, 79, 13, 65, 10,  7, 17])

In [48]:
data.loc[92811].category

47.0

In [27]:
pred_proba_categories[:, 0]

array([47, 54, 59, ..., 96, 32, 75])

In [29]:
loaded_gt[0][:10]

array([79172, 15735, 22337,   231, 74173, 41079, 38159, 71849, 69015,
       92811], dtype=int32)

In [30]:
anns[0]

array([15735, 92811, 87457, 42586, 70297,  5451, 21472, 87037,  6911,
       45784], dtype=uint32)

In [49]:
dists2, anns2 = li.search_single(
    data,
    loaded_clip_data,
    loaded_queries_seq,
    pred_proba_categories[:, 1]
)

In [50]:
anns_final = anns
dists_final = dists

In [51]:
anns_final = np.hstack((anns_final, anns2))
dists_final = np.hstack((dists_final, dists2))

In [53]:
k=10
idx_sorted = dists_final.argsort(kind='stable', axis=1)[:, :k]
idx_sorted[0]

array([10,  0, 11, 12, 13, 14, 15,  1, 16, 17])

In [54]:
idx = np.ogrid[tuple(map(slice, dists_final.shape))]
idx[1] = idx_sorted
dists_final = dists_final[tuple(idx)]
anns_final = anns_final[tuple(idx)]

In [55]:
anns_final[0]

array([79172, 15735, 22337, 74173, 41079, 38159, 69015, 92811, 99973,
       79896], dtype=uint32)

In [56]:
n_buckets=99

In [58]:
data_navigation = data
data_search = loaded_clip_data
queries_search = loaded_queries_seq

In [59]:
%%time
anns_final = None
dists_final = None
for bucket in range(n_buckets):
    dists, anns = li.search_single(
        data_navigation,
        data_search,
        queries_search,
        pred_proba_categories[:, bucket]
    )
    if anns_final is None:
        anns_final = anns
        dists_final = dists
    else:
        # stacks the results from the previous sorted anns and dists
        # *_final arrays now have shape (queries.shape[0], k*2)
        anns_final = np.hstack((anns_final, anns))
        dists_final = np.hstack((dists_final, dists))
        # gets the sorted indices of the stacked dists
        idx_sorted = dists_final.argsort(kind='stable', axis=1)[:, :k]
        # indexes the final arrays with the sorted indices
        # *_final arrays now have shape (queries.shape[0], k)
        idx = np.ogrid[tuple(map(slice, dists_final.shape))]
        idx[1] = idx_sorted
        dists_final = dists_final[tuple(idx)]
        anns_final = anns_final[tuple(idx)]

        assert anns_final.shape == dists_final.shape == (queries_search.shape[0], k)


CPU times: user 2min 53s, sys: 366 ms, total: 2min 53s
Wall time: 2min 57s


In [63]:
overlaps = []
for i in range(10_000):
    overlaps.append(np.intersect1d(anns_final[i], loaded_gt[i]).shape[0])

In [64]:
np.mean(overlaps)

9.9936

In [None]:
anns_final[0]

In [62]:
anns_final[0] == loaded_gt[0][:10]

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True])