<a href="https://colab.research.google.com/github/PacktPublishing/Modern-Computer-Vision-with-PyTorch-2E/blob/main/Chapter18/vector_stores.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
try:
  from torch_snippets import *
except ImportError:
  %pip install torch-snippets gitPython lovely-tensors
  from torch_snippets import *

from git import Repo

repository_url = 'https://github.com/sizhky/quantization'
destination_directory = '/content/quantization'
if exists(destination_directory):
  repo = Repo(destination_directory)
else:
  repo = Repo.clone_from(repository_url, destination_directory)

%cd {destination_directory}
%pip install -qq -r requirements.txt  # takes about 5 minutes
%pip install -U torchvision
# print(repo.git.pull('origin', 'main'))

# Train

In [2]:
# Set `DEBUG=false` in the line below
# to train on the larger dataset
%env DEBUG=true
!make train

env: DEBUG=true
python -m src.defect_classification.train
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100% 528M/528M [00:07<00:00, 70.3MB/s]
Downloading readme: 100% 495/495 [00:00<00:00, 2.54MB/s]
Downloading data: 100% 306M/306M [00:07<00:00, 42.0MB/s]
Downloading data: 100% 305M/305M [00:07<00:00, 42.9MB/s]
Downloading data: 100% 263M/263M [00:08<00:00, 31.4MB/s]
Generating train split: 100% 2331/2331 [00:02<00:00, 880.94 examples/s]
Generating valid split: 100% 1004/1004 [00:00<00:00, 1023.24 examples/s]
Class Balance
 
```
↯ AttrDict ↯
train
  non_defect - 50 (int)
  defect - 50 (int)
valid
  non_defect - 50 (int)
  defect - 50 (int)
```

Map: 100% 100/100 [00:20<00:00,  4.96 examples/s]
Map: 100% 100/100 [00:18<00:00,  5.28 examples/s]
Epoch: [1;36m1[0m [33mtrain_epoch_loss[0m=[1;36m0[0

# Vector Store

In [3]:
from torch_snippets import *
from src.defect_classification.train import get_datasets, get_dataloaders

trn_ds, val_ds = get_datasets(DEBUG=True)
trn_dl, val_dl = get_dataloaders(trn_ds, val_ds)

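# `model.pth` holds a full pickled model. PyTorch >= 2.6 defaults to weights_only=True,
# which fails on a pickled nn.Module, so opt out explicitly (only for files you trust).
model = torch.load('model.pth', weights_only=False).cuda().eval()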



In [4]:
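results = []
with torch.no_grad():  # inference only, so skip building the autograd graph
  for batch in trn_dl:
    # One 512-d embedding per image, taken from spatial position (0, 0) of the pooled VGG16 feature map
    inter = model.avgpool(model.features(batch[0].cuda()))[:, :, 0, 0].cpu().numpy()
    results.append(inter)
# np.concatenate also copes with a smaller final batch, unlike np.array on a list of batches
results = np.concatenate(results).reshape(-1, 512).astype('float32')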
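The `[:, :, 0, 0]` slice deserves a note. In torchvision's VGG16, `avgpool` is `AdaptiveAvgPool2d((7, 7))`, so for a 224×224 input the pooled map is `(B, 512, 7, 7)` and the slice keeps only the top-left cell. Whether that applies here depends on how the repository builds `model`, which this notebook does not show. The sketch below uses a random NumPy array as a hypothetical stand-in for the pooled map (no real model involved) to contrast the slice with global average pooling, one common alternative.

```python
import numpy as np

# Hypothetical stand-in for model.avgpool(model.features(x)): batch of 4, 512 channels, 7x7 grid
rng = np.random.default_rng(0)
pooled = rng.standard_normal((4, 512, 7, 7)).astype(np.float32)

# What the notebook does: keep only spatial position (0, 0) -> one 512-d vector per image
corner = pooled[:, :, 0, 0]

# Alternative: average over all 7x7 positions (global average pooling)
global_avg = pooled.mean(axis=(2, 3))

print(corner.shape, global_avg.shape)  # both (4, 512)
# The two embeddings only coincide when the spatial map is already 1x1
print(np.allclose(corner, global_avg))
```

If the repository's `avgpool` already reduces to 1×1, both choices give the same vector and the slice is simply a reshape.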

In [6]:
import faiss
import numpy as np

index = faiss.IndexFlatL2(results.shape[1])  # L2 distance
index.add(results)
faiss.write_index(index, "index_file.index")
# Embed the first validation image the same way as the training images
with torch.no_grad():
  im = val_ds[0]['image'][None].cuda()
  query_vector = model.avgpool(model.features(im))[:, :, 0, 0].cpu().numpy().astype('float32')  # shape (1, 512)
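`IndexFlatL2` does no compression or partitioning: `search` compares each query against every stored vector and returns, per query row, the `k` smallest *squared* L2 distances in `D` and the matching row ids in `I`, both shaped `(n_queries, k)`. The NumPy sketch below reproduces that behaviour on small random data so the outputs can be inspected. `flat_l2_search` is a hypothetical helper, not a faiss function, and tie-breaking among equal distances may differ from faiss.

```python
import numpy as np

def flat_l2_search(database, queries, k):
    """Exact k-NN by brute force, returning (D, I) like a flat L2 index.

    D holds squared L2 distances (no square root); I holds row indices into `database`.
    """
    # (n_queries, 1, d) - (1, n_db, d) -> (n_queries, n_db, d), then sum over d
    diffs = queries[:, None, :] - database[None, :, :]
    sq_dists = np.sum(diffs * diffs, axis=-1)
    # Sort each query's distances and keep the ids of the k closest
    I = np.argsort(sq_dists, axis=1)[:, :k]
    D = np.take_along_axis(sq_dists, I, axis=1)
    return D, I

rng = np.random.default_rng(0)
database = rng.standard_normal((96, 512)).astype(np.float32)  # stands in for `results`
query = database[5:6] + 0.01  # a query lying very close to row 5
D, I = flat_l2_search(database, query, k=3)
print(I[0, 0])  # 5: the nearest stored vector
print(D.shape, I.shape)  # (1, 3) (1, 3)
```

Because `D` is squared, take `np.sqrt(D)` if actual Euclidean distances are needed.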

In [7]:
%%time
k = 3  # Number of nearest neighbors to retrieve
D, I = index.search(query_vector.astype('float32'), k)

CPU times: user 194 µs, sys: 0 ns, total: 194 µs
Wall time: 200 µs


## Vector Store on 960k vectors instead of 96

In [8]:
vectors = np.tile(results, (10000, 1)).astype(np.float32)  # repeat all rows 10,000 times (96 -> 960k)
print(vectors.shape)
index = faiss.IndexFlatL2(vectors.shape[1])  # L2 distance
index.add(vectors)
faiss.write_index(index, "index_file_960k.index")
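Repeating `results.tolist()` 10,000 times builds a Python list of 960,000 rows before NumPy copies it back into an array. `np.tile(results, (10000, 1))` produces the same rows in the same order directly as `float32`. At full size the array holds 960,000 × 512 float32 values, i.e. 960,000 × 512 × 4 = 1,966,080,000 bytes (about 1.97 GB), and `IndexFlatL2` keeps its own copy of the added vectors, so expect both allocations. The sketch below checks the equivalence on a small random stand-in for `results` (an assumption; the real embeddings are not used).

```python
import numpy as np

rng = np.random.default_rng(0)
results = rng.standard_normal((96, 512)).astype(np.float32)  # stand-in for the real embeddings
repeats = 10  # the notebook uses 10000; kept small here

# List repetition, then conversion back to an array
via_list = np.array(results.tolist() * repeats, dtype=np.float32)
# Vectorised alternative: tile the rows, no Python lists involved
via_tile = np.tile(results, (repeats, 1))

print(via_tile.shape)  # (960, 512)
print(np.array_equal(via_list, via_tile))  # True

# Bytes needed for the full-size array in the notebook (960,000 x 512 float32)
full_bytes = 96 * 10000 * 512 * np.dtype(np.float32).itemsize
print(full_bytes)  # 1966080000
```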

Searching for `query_vector` in the 960k-vector flat index takes about 673 ms (wall time):

In [9]:
%%time
k = 3  # Number of nearest neighbors to retrieve
D, I = index.search(query_vector.astype('float32'), k)

CPU times: user 670 ms, sys: 2.15 ms, total: 672 ms
Wall time: 673 ms


The same brute-force search written directly in NumPy takes about 7 s (wall time):

In [11]:
%%time
distances = np.sum(np.square(query_vector - vectors), axis=1)
sorted_distances = np.sort(distances)

CPU times: user 1 s, sys: 4.56 s, total: 5.57 s
Wall time: 7 s
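The NumPy cell is not quite a like-for-like comparison: it materialises the full `(960000, 512)` difference array (roughly 2 GB at float32) and then sorts all 960,000 distances, whereas `index.search` only returns the top `k`. Below is a sketch of a leaner NumPy baseline, assuming the database norms can be precomputed once: it expands the squared distance as ||v||^2 - 2 v·q + ||q||^2 and uses `np.argpartition` to pick the `k` smallest without a full sort. `numpy_topk_l2` is a hypothetical helper and no timings are claimed for it. The expanded form can differ from the direct computation by small floating-point rounding, so near-ties may come out in a different order.

```python
import numpy as np

def numpy_topk_l2(vectors, query, k, vector_sq_norms=None):
    """Return (D, I): the k smallest squared L2 distances from one query to `vectors`, nearest first."""
    q = query.reshape(-1)  # accept (1, d) or (d,)
    if vector_sq_norms is None:
        # ||v||^2 for every stored vector; compute once and reuse across queries
        vector_sq_norms = np.einsum('ij,ij->i', vectors, vectors)
    # ||v - q||^2 = ||v||^2 - 2 v.q + ||q||^2: one mat-vec, no (N, d) temporary array
    sq_dists = vector_sq_norms - 2.0 * (vectors @ q) + q @ q
    sq_dists = np.maximum(sq_dists, 0.0)  # clip tiny negatives caused by rounding
    # Select the k smallest without a full sort, then order just those k
    idx = np.argpartition(sq_dists, k - 1)[:k]
    idx = idx[np.argsort(sq_dists[idx])]
    return sq_dists[idx], idx

# Small random stand-ins for the notebook's `vectors` and `query_vector`
rng = np.random.default_rng(0)
vectors = rng.standard_normal((2000, 512)).astype(np.float32)
query_vector = vectors[42:43] + 0.01  # lies very close to row 42

norms = np.einsum('ij,ij->i', vectors, vectors)  # precomputed once
D, I = numpy_topk_l2(vectors, query_vector, k=3, vector_sq_norms=norms)
print(I[0])  # 42
```

Precomputing `norms` moves the per-vector work out of the query path, which is the same trade a vector index makes by preparing its data ahead of search.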
