## FAISS for Animal Image Recognition

In [2]:
import os
import numpy as np
from PIL import Image

from faiss_research.modules.preprocessing import create_img_dict, create_img_list
from faiss_research.modules.embedding import ImageEmbedding
from faiss_research.modules.faiss_search import FaissSearch

### I. Load and Preprocess Images

In [4]:
# Index the training folder, then expand that index into the image list
# (elements are unpacked below as (image, name) pairs) plus label data.
folder_path = 'storage/animal_images/training'
animal_dict = create_img_dict(folder_path)
(
    animal_img_list,
    labels,
    animal_names,
) = create_img_list(animal_dict, folder_path)
print('Number of animal images:', len(animal_img_list))

Number of animal images: 600


In [5]:
# Keep only the image half of each (image, name) pair and embed it; the
# resulting vectors are the training data every FAISS index below is built on.
preprocessed_img_list = [img for (img, _name) in animal_img_list]
img_embed = ImageEmbedding()
training_dataset = img_embed.embed_list(preprocessed_img_list)

### II. Initialize and Build FAISS Indexes

In [None]:
# Build each FAISS index type over the same embeddings and pickle it.
# A single FaissSearch instance is reused: each *_search call is followed
# immediately by save(), so each pickle is written right after its index is built.
animal_search = FaissSearch()

# (index builder, pickle path, confirmation message), in build order:
# 1. Flat search, 2. LSH search, 3. HNSW search.
index_builds = [
    (animal_search.flat_search,
     'faiss_research/pickle_files/Flat_index.pkl',
     'Flat index file saved successfully.'),
    (animal_search.lsh_search,
     'faiss_research/pickle_files/LSH_index.pkl',
     'LSH index file saved successfully'),
    (animal_search.hnsw_search,
     'faiss_research/pickle_files/HNSW_index.pkl',
     'HNSW index file saved successfully.'),
]

for build_index, save_dir, saved_msg in index_builds:
    build_index(training_dataset, labels, animal_names)
    # Ensure the output folder exists so a fresh checkout can run this cell.
    # Whether FaissSearch.save creates it itself is not visible here.
    os.makedirs(os.path.dirname(save_dir), exist_ok=True)
    animal_search.save(save_dir)
    print(saved_msg)

Flat index file saved successfully.


### III. Test with Image Queries

### i. Define Query Function

In [None]:
# Define image test function
def test_images(index, test_folder, ground_truth, k=10):
    """Query ``index`` with every image under ``test_folder`` and print accuracy.

    Each readable image is embedded with ``ImageEmbedding`` and searched
    against ``index`` for ``k`` neighbours. The returned label is compared
    position-by-position with ``ground_truth``.

    Parameters
    ----------
    index : object
        Loaded search object exposing ``search(query_vectors, k)``, which
        returns ``(label, D, I)``.
    test_folder : str
        Folder walked recursively for query images. Files PIL cannot open
        (e.g. the ground-truth text file) are skipped.
    ground_truth : list
        Expected labels. Assumed to be in the same order as the sorted image
        paths -- TODO confirm against how image_labels.txt was produced.
    k : int, default 10
        Number of neighbours requested from ``index.search``.
    """
    # os.walk/os.listdir order is arbitrary. Sort the full paths so the
    # positional pairing with ground_truth is deterministic. Join with
    # ``root`` (not test_folder) so files in subfolders resolve correctly.
    test_img_paths = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(test_folder)
        for name in files
    )

    img_embed = ImageEmbedding()
    detc_labels = []
    for img_path in test_img_paths:
        try:
            # The with-block closes the file handle. convert() loads the
            # pixels into a new in-memory image that outlives the block.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert('RGB')
        except (OSError, ValueError):
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # non-image files such as the labels text file; skip them.
            continue

        img_q = img_embed.embed_list([img])
        label, D, I = index.search(img_q, k)
        detc_labels.append(label)

    # Pair each prediction with the ground truth at the same position.
    # Assumes ``label`` compares with the strings in ground_truth via == --
    # TODO confirm the type index.search returns.
    pairs = list(zip(detc_labels, ground_truth))
    if not pairs:
        print('Accuracy: n/a (no predictions to compare)')
        print('\n')
        return
    if len(detc_labels) != len(ground_truth):
        print(f'Warning: {len(detc_labels)} predictions vs '
              f'{len(ground_truth)} ground-truth labels; '
              f'comparing the first {len(pairs)}.')

    accuracy = sum(1 if pred == truth else 0 for pred, truth in pairs) / len(pairs) * 100
    print(f'Accuracy: {accuracy:.2f}%')
    print('\n')

### ii. Set Up the Test

In [None]:
# Set up the evaluation: the query folder holds the test images and a
# plain-text file with one ground-truth label per line.
test_folder = 'storage/animal_images/query'
gt_file = os.path.join(test_folder, 'image_labels.txt')
with open(gt_file, 'r', encoding='utf-8') as f:
    # Skip blank lines so they do not become empty '' labels that shift the
    # positional comparison done in test_images.
    gt = [line.strip() for line in f if line.strip()]

print("Ground Truths:", gt)

animal_search = FaissSearch()

#### A. Flat Search

In [None]:
# Flat Search:
print("Faiss Flat Search:")
# Create a fresh FaissSearch here, as the LSH and HNSW cells do, so this cell
# does not depend on the instance left over from the setup cell.
animal_search = FaissSearch()
load_dir = 'faiss_research/pickle_files/Flat_index.pkl'

# Queries:
saved_index = animal_search.load(load_dir)
test_images(saved_index, test_folder, gt)

#### B. LSH Search

In [None]:
# Evaluate the pickled LSH index against the query images.
print("Faiss LSH Search:")
load_dir = 'faiss_research/pickle_files/LSH_index.pkl'
animal_search = FaissSearch()

# Load the saved index, then run every query image through it.
saved_index = animal_search.load(load_dir)
test_images(saved_index, test_folder, gt)

#### C. HNSW Search

In [None]:
# Evaluate the pickled HNSW index against the query images.
print("Faiss HNSW Search:")
load_dir = 'faiss_research/pickle_files/HNSW_index.pkl'
animal_search = FaissSearch()

# Load the saved index, then run every query image through it.
saved_index = animal_search.load(load_dir)
test_images(saved_index, test_folder, gt)