In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install timm
!pip install faiss-gpu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.7-py3-none-any.whl (509 kB)
[K     |████████████████████████████████| 509 kB 4.2 MB/s 
Installing collected packages: timm
Successfully installed timm-0.6.7
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[K     |████████████████████████████████| 85.5 MB 108 kB/s 
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [3]:
!gdown --id 1--a-r9-mUqV1C3jIcc4edA5X-ipOaMN6
!unzip ./preproc_images.zip
!gdown --id 1nnzAz0OZx3OPDauN2R7H6aTUnkNb2eYT

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: images/plant_Dogwood/0.48331093504660283.jpg  
  inflating: images/plant_Dogwood/0.0162917200616709.jpg  
  inflating: images/plant_Dogwood/0.6062354001262471.jpg  
  inflating: images/plant_Dogwood/0.4905092149345899.jpg  
  inflating: images/plant_Dogwood/0.9600715745087613.jpg  
  inflating: images/plant_Dogwood/0.20333053148726865.jpg  
  inflating: images/plant_Dogwood/0.8678129479911922.jpg  
  inflating: images/plant_Dogwood/0.3882028634625333.jpg  
  inflating: images/plant_Dogwood/0.8466237244257693.jpg  
  inflating: images/plant_Dogwood/0.47128305837683626.jpg  
  inflating: images/plant_Dogwood/0.27629993980053735.jpg  
   creating: images/animal_Coral/
  inflating: images/animal_Coral/0.3649292917081799.jpg  
  inflating: images/animal_Coral/0.3135225550634376.jpg  
  inflating: images/animal_Coral/0.1584579438838437.jpg  
  inflating: images/animal_Coral/0.7317349163192701.jpg  
  inflating: ima

## I. Load a Pretrained Model

In [None]:
# import timm
# from pprint import pprint
# model_names = timm.list_models(pretrained=True)
# pprint(model_names)

In [4]:
import timm
import sys
import torch
import cv2
import numpy as np
import copy
import random
from PIL import Image
import faiss
from tqdm import tqdm, tqdm_notebook
import os
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

In [5]:
class Model():
  def __init__(self, model_name, size=(224,224)):
    self.model = timm.create_model(str(model_name), pretrained=True)
    self.size = size

  def __call__(self, image_path):
    image = Image.open(str(image_path)).resize(self.size)
    image = torch.as_tensor(np.array(image, dtype=np.float32)).transpose(2,0)[None]
    return self.model(image)

In [6]:
xception41_model = Model('xception41')
output_xception41 = xception41_model('/content/apple.jpg')
print(output_xception41.shape)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth" to /root/.cache/torch/hub/checkpoints/tf_xception_41-e6439c97.pth


torch.Size([1, 1000])


In [7]:
resnest50d_model = Model('resnet50')
output_resnet50 = resnest50d_model('/content/apple.jpg')
print(output_resnet50.shape)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth" to /root/.cache/torch/hub/checkpoints/resnet50_a1_0-14fe96d1.pth


torch.Size([1, 1000])


In [8]:
vgg16_model = Model('vgg16')
output_vgg16 = vgg16_model('/content/apple.jpg')
print(output_vgg16.shape)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


torch.Size([1, 1000])


In [9]:
mobilenetv3_model = Model('mobilenetv3_rw')
output_mobilenetv3 = mobilenetv3_model('/content/apple.jpg')
print(output_mobilenetv3.shape)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth" to /root/.cache/torch/hub/checkpoints/mobilenetv3_100-35495452.pth


torch.Size([1, 1000])


## II. Build model Image Retrieval

### 2.1 Way 1: Concat

In [10]:
def concat_model(image):
  o1 = resnest50d_model(image)
  o2 = xception41_model(image)
  o3 = vgg16_model(image)
  concat = torch.concat((o1, o2, o3))
  return torch.sum(concat, dim=0, keepdims=True)

### Image Retrival with Cosine

In [None]:
from PIL import Image # Đọc ảnh
import numpy as np # Xử lý ma trận
import os # Thao tác lấy tên file, move file của OS
import matplotlib.pyplot as plt

In [None]:
dic_categories = ['scenery', 'furniture', 'animal', 'plant'] 
root_img_path = '/content/images'
files = []
for folder in os.listdir(root_img_path):
  if folder.split("_")[0] in dic_categories:
    path = os.path.join(root_img_path, folder)
    list_dir = [name for name in glob.glob(path+'/*') if name.endswith((".jpg", ".png", ".jpeg"))]
    for file in list_dir:
      files.append(file)
      
id2img_fps = dict(enumerate(files))

In [None]:
def cosine_similarity(query, X):
    axis_batch_size = tuple(range(1,len(X.shape)))
    query_norm = np.linalg.norm(query)
    X_norm = np.asarray([np.linalg.norm(i) for i in X])
    return np.sum(X*query, axis=axis_batch_size)/(query_norm*X_norm)

In [None]:
def folder_to_images(list_dir,size):    
  i = 0
  images_np = np.zeros(shape=(len(list_dir), *size))
  idx = []
  for key,img in list_dir.items():
  # try:
    idx.append(key)
    images_np[i] = concat_model(img).detach().numpy()
    i += 1
        # except Exception:
        #     print("error: ", img)
#             os.remove(path)
    return images_np,idx

In [None]:
# Create Index
ls_path_score = []
query_path = '/content/apple.jpg'
query = mobilenetv3_model(query_path).detach().numpy()
size = (1,1000)

for img_index, img in tqdm_notebook(id2img_fps.items()):
  path = root_img_path + folder
  images_np, idx = folder_to_images(id2img_fps, size)
  rates = cosine_similarity(query, images_np)
  ls_path_score.extend(list(zip(idx, rates)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


  0%|          | 0/18775 [00:00<?, ?it/s]

  """


In [None]:
def get_l1_score(root_img_path, query_path, size):
    dic_categories = ['scenery', 'furniture', 'animal', 'plant'] 
    query = read_image_from_path(query_path, size)
    ls_path_score = []
    for folder in os.listdir(root_img_path):
        if folder.split("_")[0] in dic_categories:
            path = root_img_path + folder
            images_np, images_path = folder_to_images(path, size) # mang numpy nhieu anh, paths
            rates = cosine_similarity(query, images_np)
            ls_path_score.extend(list(zip(images_path, rates)))
    return query, ls_path_score

In [None]:
def plot_results(query, ls_path_score):
    # Show query image
    plt.imshow(query/255.0)
    # Score
    fig = plt.figure(figsize=(15, 15))
    columns = 5
    rows = 6
    for i, path in enumerate(sorted(ls_path_score, reverse=True, key=lambda x : x[1])[:30], 1):
        img = np.random.randint(10, size=(10,10))
        fig.add_subplot(rows, columns, i)
        plt.imshow(plt.imread(path[0]))
        plt.axis("off")
    plt.show()

In [None]:
root_img_path = "/content/images/"
query_path = "apple.jpg"
size = (80, 80)
query, ls_path_score = get_l1_score(root_img_path, query_path, size)
plot_results(query, ls_path_score)

### Image Retrieval with Faiss

In [11]:
dic_categories = ['scenery', 'furniture', 'animal', 'plant'] 
root_img_path = '/content/images'
files = []
for folder in os.listdir(root_img_path):
  if folder.split("_")[0] in dic_categories:
    path = os.path.join(root_img_path, folder)
    list_dir = [name for name in glob.glob(path+'/*') if name.endswith((".jpg", ".png", ".jpeg"))]
    for file in list_dir:
      files.append(file)
      
id2img_fps = dict(enumerate(files))

In [None]:
# Create Index
index = faiss.IndexFlatL2(1000)

fea_indexes = []
error_indexes = []

for img_index, img in tqdm_notebook(id2img_fps.items()):
  # try:
  embedded = concat_model(img).detach().numpy()
  index.add(embedded)
  fea_indexes.append(img_index)
  # except Exception:
  #   print('error')
  #   error_indexes.append(img_index)
  #   continue



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


  0%|          | 0/18775 [00:00<?, ?it/s]

In [None]:
#### Save bin file ####
faiss.write_index(index, os.path.join('/content/drive/MyDrive/AIC_2022/CNN_Extractions', "faiss.bin"))

In [None]:
#### Or use can load Index by bin file ####
# index = faiss.read_index(os.path.join('./', "oxbuild_images-v1.bin"))

In [None]:
# Query model
query_path = '/content/apple.jpg'
query = mobilenetv3_model(query_path).detach().numpy()
f_dists, f_ids = index.search(query, k=7)
print(f"scores: {f_dists[0]}")
print(f"idx: {f_ids[0]}")
result_ids = f_ids[0][1:]

scores: [118659.734 158241.33  173093.48  194436.83  195462.8   207810.44
 212956.53 ]
idx: [ 421  668 1274 1030  530 1146  488]


In [None]:
def show_imgs(query, f_ids):
  plt.imshow(query)
  fig = plt.figure(figsize=(12, 12))
  columns = 3
  rows = 2
  for i in range(1, columns*rows +1):
    img = mpimg.imread(id2img_fps[f_ids[i - 1]])
    ax = fig.add_subplot(rows, columns, i)
    plt.imshow(img)
    plt.axis("off")
  plt.show()

In [None]:
query = Image.open(query_path) 
show_imgs(query, result_ids)