In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import h5py
data_path = 'laion2B-en-clip768v2-n=100K.h5'
%time f = h5py.File(data_path, 'r')

CPU times: user 1.16 ms, sys: 266 µs, total: 1.43 ms
Wall time: 63 ms


In [3]:
import pandas as pd
%time data = pd.DataFrame(list(f['emb']))

CPU times: user 35.9 s, sys: 2.14 s, total: 38.1 s
Wall time: 45.1 s


In [4]:
data.head(2)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,-0.00198,0.043152,-0.012474,-0.042297,0.037201,0.034027,0.007248,0.026535,-0.022705,-0.038544,...,-0.032562,-0.043121,0.015541,-0.006378,-0.051178,0.016327,-0.020538,-0.008286,-0.009285,-0.024933
1,0.057281,0.052429,0.042633,0.050262,0.001604,0.0186,-0.00032,-0.013687,-0.013916,0.02356,...,0.042267,0.015373,-0.006653,0.018616,-0.061127,0.033722,-0.000931,0.006901,0.01886,-0.008125


In [5]:
q_path = 'public-queries-10k-clip768v2.h5'
# The context manager closes the HDF5 handle once the embeddings are in memory;
# the [:] slice reads the dataset in one bulk call instead of row by row.
# float64 keeps the dtype that the per-row construction produced.
with h5py.File(q_path, 'r') as f:
    queries = pd.DataFrame(f['emb'][:].astype('float64'))

In [6]:
gt_path = 'laion2B-en-public-gold-standard-v2-100K.h5'
# Bulk-read the neighbour table and close the file right after loading.
# NOTE(review): the int64 cast assumes 'knns' stores integer object ids -- confirm.
with h5py.File(gt_path, 'r') as f:
    gt_knns = pd.DataFrame(f['knns'][:].astype('int64'))

In [7]:
from sklearn.metrics.pairwise import cosine_similarity
def pairwise_cosine(x, y):
    """Return the cosine-distance matrix between the rows of `x` and the rows of `y`.

    The distance is one minus scikit-learn's `cosine_similarity`, so entry
    [i, j] relates row i of `x` to row j of `y`.
    """
    similarity = cosine_similarity(x, y)
    distance = 1 - similarity
    return distance

In [8]:
# how long does it take to brute-force search for most similar datapoints
%time res = pairwise_cosine([data.iloc[0]], data)

CPU times: user 482 ms, sys: 209 ms, total: 691 ms
Wall time: 707 ms


### training approach #1
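- Pick 100 random objects as category seeds, compute each seed's nearest neighbours, and keep the closest 1% (1k objects) as that seed's category. Train a classifier on these labels; when an object falls into several categories, assign it to the category whose seed is closest.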
- Q: does it scale?

In [10]:
import numpy as np

In [12]:
%%time

# Positions of the 1000 nearest objects; rank 0 of the ascending sort is skipped
# (presumably the query object itself at distance ~0 -- not verified here).
idxs = np.argsort(res[0])[1:1001]
# Rows of `cat`: 1) ids of the similar objects, 2) their distances,
# 3) a "should keep" bitmap that starts out as all ones.
cat = np.vstack((idxs, res[0][idxs], np.ones(len(idxs))))

CPU times: user 20.5 ms, sys: 27 µs, total: 20.5 ms
Wall time: 20.6 ms


In [13]:
%%time

# Use cosine *distance*, as for `cat`: an ascending argsort over raw similarities
# would select the least similar objects instead of the nearest ones.
res2 = pairwise_cosine([data.iloc[1]], data)
idxs2 = res2[0].argsort()[1:1001]
# Rows: 1) similar object ids, 2) their distances, 3) should-keep bitmap.
# The bitmap is sized from idxs2 itself rather than from the unrelated `idxs`.
cat2 = np.vstack([idxs2, res2[0][idxs2], np.ones(idxs2.shape[0])])

CPU times: user 468 ms, sys: 192 ms, total: 660 ms
Wall time: 684 ms


In [68]:
def resolve_overlaps(categories):
    """Resolve shared objects between every pair of categories, in place.

    Each category is indexed as ``cat[0]`` for object ids; the keep-bitmap row
    is updated by `resolve_overlap_objects`, which clears the flag on whichever
    side of a pair holds the shared object at the larger distance.

    Every unordered pair (i, j) with i < j is visited exactly once, so an
    object present in several categories ends up kept only where its distance
    is smallest.

    Parameters
    ----------
    categories : sequence of numpy.ndarray
        Category arrays; mutated in place.

    Returns
    -------
    The same `categories` object, for convenience.
    """
    for i, cat in enumerate(categories):
        # Only look at later categories: earlier pairs were already handled and
        # a category never needs to be compared with itself.
        for next_cat in categories[i + 1:]:
            overlap_objs = np.intersect1d(cat[0], next_cat[0], return_indices=True)
            if overlap_objs[0].size != 0:
                # The helper mutates both arrays in place, so no rebinding is needed.
                resolve_overlap_objects(cat, next_cat, overlap_objs)
    return categories

In [93]:
cat

array([[5.75380000e+04, 7.11530000e+04, 9.22260000e+04, ...,
        7.47800000e+04, 5.78570000e+04, 5.03650000e+04],
       [1.46781577e-01, 1.50667473e-01, 1.59918897e-01, ...,
        3.33627767e-01, 3.33633559e-01, 3.33656429e-01],
       [1.00000000e+00, 1.00000000e+00, 1.00000000e+00, ...,
        1.00000000e+00, 1.00000000e+00, 1.00000000e+00]])

In [92]:
overlap_objs

(array([], dtype=float64), array([], dtype=int64), array([], dtype=int64))

In [100]:
overlap_objs[0].size == 0

True

In [65]:
%%time
# One-off overlap check of `cat` against `cat2`; the loop body runs once because
# the list of previously built categories contains only `cat`.
for prev_cat in [cat]:
    # With return_indices=True the result is a 3-tuple: the common ids, their
    # positions in prev_cat[0], and their positions in cat2[0].
    overlap_objs = np.intersect1d(prev_cat[0], cat2[0], return_indices=True)
    if overlap_objs[0].size != 0:
        cat, cat2 = resolve_overlap_objects(cat, cat2, overlap_objs)

CPU times: user 406 µs, sys: 0 ns, total: 406 µs
Wall time: 412 µs


In [66]:
import logging
logging.basicConfig(level=logging.INFO)
logging.info('Initialized logger')

INFO:root:Initialized logger


In [204]:
def create_category(data, random_state_offset, n_neighbors=1000, base_seed=2023):
    """Build one category: a reproducibly sampled seed object and its nearest neighbours.

    Parameters
    ----------
    data : pandas.DataFrame
        One embedding per row.
    random_state_offset : int
        Added to `base_seed` to form the sampling seed, so different offsets
        draw different (but reproducible) seed objects.
    n_neighbors : int, default 1000
        Number of nearest neighbours to keep for the category.
    base_seed : int, default 2023
        Base value of the sampling seed.

    Returns
    -------
    tuple
        ``(seed_label, neighbours)`` where `seed_label` is the index label of
        the sampled row and `neighbours` is a 2-row array: row 0 holds the
        positions of the neighbours within `data`, row 1 their cosine distances
        to the seed object.
    """
    main_datapoint = data.sample(1, random_state=base_seed + random_state_offset)
    distances = pairwise_cosine(main_datapoint.values, data)[0]
    # Rank 0 of the ascending sort is skipped: it is expected to be the seed
    # object itself, at distance ~0.
    idxs = distances.argsort()[1:n_neighbors + 1]
    return (main_datapoint.index[0], np.vstack([idxs, distances[idxs]]))


In [205]:
# Not used: overlaps are resolved with pandas instead (sort by distance, then drop duplicate object ids).
def resolve_overlap_objects(cat1, cat2, overlap_objs):
    """Clear the keep flag of each shared object on the side where it is farther away.

    Parameters
    ----------
    cat1, cat2 : numpy.ndarray
        Category arrays indexed as ``cat[1]`` (distances) and ``cat[2]``
        (keep-bitmap). Both are modified in place.
    overlap_objs : tuple
        Result of ``np.intersect1d(cat1[0], cat2[0], return_indices=True)``:
        element 1 holds positions in `cat1`, element 2 the matching positions
        in `cat2`.

    Returns
    -------
    tuple
        ``(cat1, cat2)``, the same array objects that were passed in.

    Notes
    -----
    On a distance tie the object stays in `cat1` and is dropped from `cat2`.
    An empty overlap leaves both arrays untouched.
    """
    idx1 = np.atleast_1d(overlap_objs[1])
    idx2 = np.atleast_1d(overlap_objs[2])
    # Compare the shared objects pairwise: position k of idx1 and position k of
    # idx2 refer to the same object.
    cat1_is_farther = cat1[1][idx1] > cat2[1][idx2]
    cat1[2][idx1[cat1_is_farther]] = 0
    cat2[2][idx2[~cat1_is_farther]] = 0
    return cat1, cat2

In [206]:
from tqdm import tqdm

In [235]:
%%time

categories = []
main_objs = []
# One category per seed offset. Note that this loop rebinds the notebook-level
# name `cat`, replacing any array stored under it by earlier cells.
for i in tqdm(range(100)):
    obj_id, cat = create_category(data, random_state_offset=i)
    main_objs.append(obj_id)
    categories.append(cat)

100%|██████████| 100/100 [00:55<00:00,  1.82it/s]

CPU times: user 39 s, sys: 14.7 s, total: 53.7 s
Wall time: 55.3 s





In [236]:
%%time

# Build one labelled frame per category and concatenate them in a single call.
# Growing df_all with pd.concat inside the loop re-copies all accumulated rows on
# every iteration, and seeding it with an empty frame forced category_id to float.
category_frames = [
    pd.DataFrame({
        'object_id': c[0].astype(int),
        'dist': c[1],
        'category_id': cat_id,
    })
    for cat_id, c in enumerate(categories)
]

# For objects claimed by several categories keep only the row with the smallest
# distance, i.e. assign each object to its closest category.
df_all = pd.concat(
    category_frames
).sort_values(
    'dist', ascending=True
).drop_duplicates(
    'object_id', keep='first'
)

CPU times: user 319 ms, sys: 0 ns, total: 319 ms
Wall time: 320 ms


In [238]:
df_all.category_id.value_counts()

28.0    886
11.0    846
51.0    811
56.0    807
93.0    800
       ... 
12.0    198
29.0    189
4.0     173
37.0    129
78.0    114
Name: category_id, Length: 100, dtype: int64

### Time to create training labels for 100 categories (100k data): ~ 1min

### Train without train-test split 

In [243]:
%%time

from sklearn.linear_model import LogisticRegression
# Features are the embeddings of the labelled objects; targets are their category ids.
X = data.loc[df_all['object_id'].astype(int)]
y = df_all['category_id'].astype(int)

clf = LogisticRegression(
    random_state=2023,
    max_iter=500,
).fit(X, y)

CPU times: user 1min 15s, sys: 242 ms, total: 1min 15s
Wall time: 1min 16s


In [261]:
%time res = clf.predict_proba(X.iloc[:1])

CPU times: user 11.4 ms, sys: 0 ns, total: 11.4 ms
Wall time: 14.8 ms


In [263]:
%time res = clf.predict(X)

CPU times: user 395 ms, sys: 0 ns, total: 395 ms
Wall time: 397 ms


In [288]:
%time data_categories = pd.DataFrame(clf.predict(data), index=data.index, columns=['category'])

CPU times: user 718 ms, sys: 0 ns, total: 718 ms
Wall time: 744 ms


In [289]:
data_categories

Unnamed: 0,category
0,1
1,85
2,10
3,8
4,50
...,...
99995,33
99996,22
99997,17
99998,50


In [264]:
clf.score(X, y)

0.8434859749915512

## evaluate time taken for 90% recall on gt, queries

In [327]:
sample_query = queries.iloc[[1]]

In [328]:
%time predicted_category = np.argmax(clf.predict_proba(sample_query))

CPU times: user 11.4 ms, sys: 0 ns, total: 11.4 ms
Wall time: 11.4 ms


In [329]:
%time bucket_obj_indexes = data_categories.query('category == @predicted_category').index

CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 4.01 ms


In [330]:
%%time

bucket_df = data.loc[bucket_obj_indexes]
final_dists = pairwise_cosine(sample_query, bucket_df)

CPU times: user 15.1 ms, sys: 0 ns, total: 15.1 ms
Wall time: 15.1 ms


In [331]:
%time final_gts = final_dists[0].argsort()[:10]

CPU times: user 155 µs, sys: 38 µs, total: 193 µs
Wall time: 197 µs


In [332]:
anns = bucket_df.iloc[final_gts].index + 1
anns

Int64Index([14347, 82848, 79302, 85923, 6016, 67067, 29567, 54566, 34591,
            11620],
           dtype='int64')

In [333]:
def evaluate_recall(anns, knns, k=10):
    """Return the fraction of the `k` expected neighbours that were retrieved.

    `knns` must provide an `intersection` method whose result exposes `shape`
    (the notebook passes a `pandas.Index`); `anns` holds the retrieved ids.
    """
    shared = knns.intersection(anns)
    hit_count = shared.shape[0]
    return hit_count / k

In [334]:
%time evaluate_recall(anns, pd.Index(gt_knns.iloc[1, :10].values))

CPU times: user 666 µs, sys: 163 µs, total: 829 µs
Wall time: 834 µs


1.0

## all queries, single bucket

In [316]:
queries.shape[0]

10000

In [318]:
queries.iloc[[0]]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.001425,0.025776,-0.02312,-0.045599,0.006402,0.002812,0.016146,-0.039648,-0.028935,-0.024204,...,0.005059,-0.031346,-0.001856,-0.025669,0.008989,-0.005685,0.044714,-0.007112,-0.001487,0.015734


In [342]:
%%time

recalls = []

# stop cond: single bucket
# time to evaluate 10k queries:
# mean recall: 0.5
def search(query_idx, k=10):
    """Answer one query from a single predicted bucket and return its recall@k.

    The classifier picks the most probable category for the query, the query
    is compared by cosine distance against the data objects assigned to that
    category, and the `k` closest ones are scored against the first `k`
    ground-truth neighbours of the query.

    Parameters
    ----------
    query_idx : int
        Row position of the query in `queries` (and of its row in `gt_knns`).
    k : int, default 10
        Number of neighbours to retrieve and to evaluate against.

    Returns
    -------
    float
        Recall in [0, 1]; 0.0 when the predicted bucket contains no objects.
    """
    query = queries.iloc[[query_idx]]
    # argmax over predict_proba is a position within clf.classes_, not a label;
    # map it back so the lookup stays right even if labels are not 0..n-1.
    predicted_category = clf.classes_[np.argmax(clf.predict_proba(query))]
    bucket_obj_indexes = data_categories.query('category == @predicted_category').index
    bucket_df = data.loc[bucket_obj_indexes]
    if bucket_df.empty:
        # Nothing to compare against: no neighbour can be retrieved.
        return 0.0
    final_dists = pairwise_cosine(query, bucket_df)[0]
    final_gts = final_dists.argsort()[:k]
    # NOTE(review): the +1 presumably converts 0-based row labels to 1-based
    # ground-truth ids -- confirm against the gold-standard file.
    anns = bucket_df.iloc[final_gts].index + 1
    # Use the same k for the ground-truth slice and for the recall denominator.
    return evaluate_recall(anns, pd.Index(gt_knns.iloc[query_idx, :k].values), k=k)

# Appending inside the loop (instead of building the list in one expression)
# keeps the recalls computed so far available if the run is interrupted.
for q_idx in tqdm(range(queries.shape[0])):
    recalls.append(search(q_idx))

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt



In [343]:
np.mean(recalls)

0.47067847343477176