## ⚙ Install FAISS

### 🐌 CPU only

In [1]:
# !apt -q install libomp-dev
# !pip -q install faiss-cpu --no-cache

### 🚀 GPU (Can be used both on CPU or GPU) - **Recommended**



In [2]:
# !apt -q install libomp-dev
# !pip -q install faiss-gpu

## 🛠 Imports

In [3]:
import pandas as pd
import numpy as np
import faiss
from tqdm.notebook import tqdm
import requests
import os
from urllib.parse import urlencode
import zipfile
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

pd.options.mode.chained_assignment = None
pd.set_option('display.max_columns', None)

## 📥 Download data to workdir

- call the downloader with **small** parameter for small version of dataset
- or with **large** parameter for full version


In [4]:
def downloader(size: str='small'):
    if size not in ['small', 'large']:
        raise Exception('Unknown Argument')
    elif size == 'small':
        public_key = 'https://disk.yandex.ru/d/YQElc_cNQQLSOw'
    else:
        public_key = 'https://disk.yandex.ru/d/BBEphK0EHSJ5Jw'

    base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'

    final_url = base_url + urlencode(dict(public_key=public_key))
    response = requests.get(final_url)
    download_url = response.json()['href']

    download_response = requests.get(download_url)
    with open('/content/data.zip', 'wb') as f:
        f.write(download_response.content)

    zip_path = ('/content/data.zip')

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall()

In [8]:
df_path = 'data/match/'

## 😵‍💫 First look

In [9]:
%%time
df_base = pd.read_csv(df_path+"base.csv", index_col=0)
df_base.head()

CPU times: user 3.07 s, sys: 189 ms, total: 3.25 s
Wall time: 3.25 s


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71
Id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1
4207931-base,-43.946243,15.364378,17.515854,-132.31146,157.06442,-4.069252,-340.63086,-57.55014,128.39822,45.090958,-126.84374,4.494522,-99.84231,44.926903,177.52173,-12.29179,38.47036,105.35765,-142.46024,-80.16326,-110.368935,1047.517357,-69.59462,66.31354,84.87387,813.770071,-81.03878,16.162964,-98.24488,159.53406,27.554913,-209.18428,62.05977,-529.295053,114.59833,90.469894,-20.256914,-164.768,-133.31387,-41.25296,-10.251193,8.289038,-131.31271,75.7045,-16.483078,40.771038,-146.09674,-143.40768,49.807987,63.43448,-30.25008,20.470263,78.07991,-128.91531,92.32768,63.88557,-141.17464,142.90259,-93.068596,-568.421584,-90.01869,-129.01567,-71.92717,30.711966,-90.190475,-24.931271,66.972534,106.346634,-44.270622,155.98834,-1074.464888,-25.066608
2710972-base,-73.00489,4.923342,-19.750746,-136.52908,99.90717,-70.70911,-567.401996,-128.89015,109.914986,201.4722,-186.2265,29.896042,-99.770996,0.126302,136.19049,-35.22474,-30.321323,-43.148834,-162.85175,-79.71451,-75.78487,1507.231274,-69.654564,43.640663,-4.779669,813.770071,43.976913,11.924875,-50.228523,166.0082,-59.505333,-115.33252,72.18324,-735.671365,96.3223,85.79636,-22.03033,-147.54501,-108.38295,-45.084892,-15.004004,-1.532826,-46.456585,197.57895,-56.199876,60.29871,-102.65334,-108.967964,58.512012,-9.678028,-85.4483,-68.68608,71.5902,-232.42569,91.706856,63.290657,-137.33595,-47.124687,-148.0574,-543.787056,-160.6516,-133.46222,-109.04466,20.916021,-171.20139,-110.596844,67.7301,8.909615,-9.470253,133.29536,-545.897014,-72.91323
1371460-base,-85.56557,-0.493598,-48.374817,-157.98502,96.80951,-81.71021,-22.297688,79.76867,124.357086,105.71518,-149.80756,-54.50168,-21.037973,-24.88766,128.38864,-58.558483,34.862656,19.784412,-130.9182,-79.03223,-166.63525,1507.231274,-8.495993,61.205086,25.895348,813.770071,-140.76886,20.87279,-123.95757,126.34781,11.713674,-125.025154,152.6859,-1018.469545,-22.4446,73.89764,9.190645,-156.51881,-92.18573,-34.92676,-13.277475,16.026424,-33.853546,119.60452,-52.525341,71.20475,-178.70294,-88.2785,30.501453,16.651737,-88.377014,-55.883583,70.18298,-89.233925,92.00578,76.458725,-131.14087,40.914352,-157.90054,-394.319235,-87.107025,-120.772545,-58.82165,41.369606,-132.9345,-43.016839,67.871925,141.77824,69.04852,111.72038,-1111.038833,-23.087206
3438601-base,-105.56409,15.393871,-46.223934,-158.11488,79.514114,-48.94448,-93.71301,38.581398,123.39796,110.324326,-161.188,-68.51979,-0.60733,38.733696,120.74344,-14.109269,28.868027,-29.85881,-94.30395,-79.33981,-138.98427,1507.231274,-131.88538,70.03136,32.736595,813.770071,-62.37086,13.763219,-31.872276,139.5527,9.836465,-150.22113,80.1402,-537.183707,3.091667,129.69933,-63.429424,-169.02724,-119.77007,-28.637785,-8.315162,2.752385,-160.29382,85.08689,-18.25175,90.374054,1.479935,-121.98305,65.85266,8.355225,34.118896,-57.069756,70.4618,-127.90541,94.31428,71.25994,-135.57787,-39.982346,-159.75156,-230.147648,-95.22116,-148.81409,-87.90729,-58.80687,-147.7948,-155.830237,68.974754,21.39751,126.098785,139.7332,-1282.707248,-74.52794
422798-base,-74.63888,11.315012,-40.204174,-161.7643,50.507114,-80.77556,-640.923467,65.225,122.34494,191.46585,-156.98384,-76.65021,-75.67497,12.624029,145.33752,-35.774258,11.598761,-11.460761,-201.35443,-77.779366,-120.9684,548.736883,19.851685,17.943344,27.06332,813.770071,-85.48378,21.236433,-95.07102,132.61092,13.526038,-160.47684,104.71937,-304.174382,-15.385452,91.418655,-36.474556,-157.43959,-102.83162,-56.78271,-19.969252,-0.598189,-222.22879,33.441666,-56.09211,71.27603,-8.713509,-86.09938,8.488903,-14.959278,86.812996,-29.666779,64.417755,-56.716187,80.90172,69.3969,-137.74811,23.644325,-101.447716,-341.115945,22.303604,-131.1714,-30.002094,53.64293,-149.82323,176.921371,69.47328,-43.39518,-58.947716,133.84064,-1074.464888,-1.164146


In [10]:
df_base.duplicated().sum()

11278

In [11]:
df_base.shape

(291813, 72)

In [13]:
df_train = pd.read_csv(df_path+"train.csv", index_col=0)
df_train.head()

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,Target
Id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1
109249-query,-24.021454,3.122524,-80.947525,-112.329994,191.09018,-66.90313,-759.626065,-75.284454,120.55149,131.1317,-149.21106,-102.31221,21.387623,11.277594,143.2214,-22.01157,-3.618249,-16.00548,-133.38228,-78.89356,-65.69053,407.773575,-11.660624,67.00815,24.975033,813.770071,40.051064,17.933155,-75.435745,149.8172,-23.413877,-178.09557,133.78647,-906.571061,113.35556,83.94226,-16.592659,-146.52074,-120.23786,-27.341612,-8.845615,1.027612,-175.64772,167.73582,-32.931559,47.86096,-196.2475,-118.81005,-4.762772,-114.87768,37.397278,-55.616966,56.627056,-108.43317,87.37256,76.51343,-136.27057,3.652915,-164.57451,-635.284275,-75.647255,-116.67934,-41.234684,-24.60167,-167.76077,133.678516,68.1846,26.317545,11.938202,148.54932,-778.563381,-46.87775,66971-base
34137-query,-82.03358,8.115866,-8.793022,-182.9721,56.645336,-52.59761,-55.720337,130.05925,129.38335,76.20288,-137.79942,33.30165,-2.868191,-34.31877,189.06479,-19.33755,-14.20821,-71.110245,-157.74814,-78.70069,-91.741875,1054.2056,-41.84563,102.12862,72.55905,813.770071,-37.957787,17.598982,-159.9754,140.02528,-8.819328,-147.05518,113.81987,-529.295053,70.67494,55.976795,8.817799,-134.14812,-73.679794,-57.566544,-4.338496,-3.270682,-144.4992,144.6502,-37.903276,58.913525,-105.36284,-125.66783,19.367283,-29.087658,-35.02135,26.627962,55.718437,-110.52611,83.513374,75.92613,-135.68242,-7.429803,-180.64502,11.470171,16.464691,-121.807236,-90.81445,54.448433,-120.894806,-12.292085,66.608116,-27.997612,10.091335,95.809265,-1022.691531,-88.564705,1433819-base
136121-query,-75.71964,-0.223386,-86.18613,-162.06406,114.320114,-53.3946,-117.261013,-24.857851,124.8078,112.190155,-200.92596,-38.86518,-80.61127,14.343805,156.62129,-22.498169,-26.359468,-109.03487,-106.92659,-79.74731,-69.87683,1507.231274,-20.058287,34.334927,23.592144,813.770071,-49.50386,22.1662,-85.74016,134.83647,-69.56985,-139.88724,67.377045,-341.781842,54.161224,81.89166,36.421352,-159.99583,-131.91608,-20.495195,-13.976569,-2.355247,-216.22865,238.83649,-56.611536,43.36664,7.191841,-159.48369,-19.338009,-51.409897,36.81954,32.53688,80.68102,-232.40741,84.05369,59.08618,-139.8595,78.40944,-115.940575,2.426572,7.594826,-126.520134,-73.14896,-5.609123,-93.02988,-80.997871,63.733383,11.378683,62.932007,130.97539,-1074.464888,-74.861176,290133-base
105191-query,-56.58062,5.093593,-46.94311,-149.03912,112.43643,-76.82051,-324.995645,-32.833107,119.47865,120.07479,-61.347084,-28.6706,-102.79018,-36.19432,157.18976,-33.31824,7.448413,-47.230713,-178.04608,-78.78652,-106.23544,1507.231274,-63.414307,38.099255,-89.79535,813.770071,-107.43239,10.052701,-71.91738,147.74005,-18.750763,-143.79562,67.20731,-366.139446,112.1877,78.14481,-41.08541,-132.75719,-89.44503,-19.267069,-14.866466,7.775788,-104.30211,74.622894,-59.875136,76.40647,-77.79702,-92.01658,19.3373,-37.922787,37.27127,111.63957,94.91295,-179.7254,86.60148,62.698364,-122.16293,29.87394,-53.50812,-0.938894,-36.919907,-144.555,-96.79859,21.624313,-158.88037,179.597294,69.89136,-33.804955,233.91461,122.868546,-1074.464888,-93.775375,1270048-base
63983-query,-52.72565,9.027046,-92.82965,-113.11101,134.12497,-42.423073,-759.626065,8.261169,119.49023,172.36536,-186.64139,-84.9438,-92.339966,-30.229528,167.86163,-22.635653,0.014536,-9.796367,-213.1018,-78.59006,-98.7283,1250.423749,-43.892487,86.28845,-1.549826,813.770071,-110.35698,24.055641,-96.57827,156.5823,45.12424,-123.888504,118.03511,-607.946912,52.31141,76.7478,-14.161914,-143.53851,-124.886215,-64.78333,-17.706848,15.446568,-53.554455,174.38162,-23.140892,76.41933,-73.357605,-128.12526,-34.57149,-2.756741,44.027752,-13.445387,62.028725,-99.98626,79.376854,49.96618,-131.30576,-71.27052,-262.39697,-21.395427,-43.73464,-127.42511,-81.566216,13.807772,-208.65004,41.742014,66.52242,41.36293,162.72305,111.26131,-151.162805,-33.83145,168591-base


In [14]:
df_train.shape

(9999, 73)

In [16]:
df_valid = pd.read_csv(df_path+"validation.csv", index_col=0)
answers = pd.read_csv(df_path+"validation_answer.csv", index_col=0)
df_valid.shape

(10000, 72)

## 📊 Create FAISS [index](https://github.com/facebookresearch/faiss/wiki/Faiss-indexes) for small dataset


[Guideline](https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index)

Hint: Use numpy [ascontigiousarray](https://numpy.org/doc/stable/reference/generated/numpy.ascontiguousarray.html) - object which is stored in one [unbroken block](https://www.educative.io/answers/what-is-the-numpyascontiguousarray-function-in-python) in memory -  to load vectors in FAISS



In [24]:
dims = df_base.shape[1]
n_cells = 10
quantizer = faiss.IndexFlatL2(dims)
idx_l2 = faiss.IndexIVFFlat(quantizer, dims, n_cells)

In [25]:
%%time
idx_l2.train(np.ascontiguousarray(df_base.values).astype('float32'))
idx_l2.add(np.ascontiguousarray(df_base.values).astype('float32'))

CPU times: user 960 ms, sys: 482 ms, total: 1.44 s
Wall time: 672 ms


In [26]:
base_index = {k: v for k, v in enumerate(df_base.index.to_list())}

## 🔍 Search

In [27]:
targets = df_train["Target"]
df_train.drop("Target", axis=1, inplace=True)

KeyError: 'Target'

In [28]:
%%time
candidate_number = 5
r, idx = idx_l2.search(np.ascontiguousarray(df_train.values).astype('float32'), candidate_number)

CPU times: user 36.8 s, sys: 55.2 ms, total: 36.9 s
Wall time: 9.47 s


## 📈 Accuracy@candidate_number calculation

In [29]:
acc = 0
for target, el in zip(targets.values.tolist(), idx.tolist()):
    acc += int(target in [base_index[r] for r in el])
print(f'Accuracy @ {candidate_number} = {acc / len(idx):.1%}')

Accuracy @ 5 = 14.6%


## ❓❓❓ What's next?

For full dataset it is strongly recommended to test your code on the small batch before loading all dataset to FAISS

You can make your own research:
- change number of cells
- change number of candidates
- change indexes
- add another ML models to improve the FAISS result
- change the accelerator: Hint: Search method on GPU differs a bit from the similar method on CPU
-.....

Remember, that in Colab you have only 12 GB of RAM, so remove variables and objects if necessary

**Good Luck!**