In [1]:
def to_device(x, device="cuda:0"):
    if isinstance(x, dict):
        return {k: to_device(v) for k, v in x.items()}
    return x.to(device=device)

In [2]:
! pip install -qU hazm emoji
# ! pip install img2vec-pytorch

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.7/316.7 KB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m240.9/240.9 KB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m233.6/233.6 KB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m42.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for nltk (setup.py) ... [?25l[?25hdone
  Building wheel for emoji (setup.py) ... [?25l[?25hdone
  Building wheel for libwapiti (setup.py) ... [?25l[?25hdone


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import shutil

src = '/content/drive/MyDrive/base_models'
dst = '/content'
shutil.copytree(src, dst, symlinks=False, ignore=None, ignore_dangling_symlinks=False, dirs_exist_ok=True)

'/content'

In [5]:
import torch
import glob
from PIL import Image
from img2vec import Img2Vec
from model import Corrnet
import numpy as np
import pickle
from utils import predict
from preprocessing import Preprocessor
from utils_word2vec import text_standardization
from IPython.display import display
from IPython.display import Image as Img


In [6]:
class ImageSearchDemo:
    def __init__(self):
        self.img2vec = Img2Vec('resnet-18', 'default', 512, cuda=True)
        self.w2v_weights = np.load("/content/w2v_embedding.npz")['arr_0']
        self.w2v_vocabs = pickle.load(open("/content/vocabs.pkl", "rb"))
        self.preprocessor = Preprocessor()
        self.model_save_path = '/content/model_state.pt'
        self.corrnet = Corrnet(512, 50)
        self.corrnet.load_state_dict(torch.load(self.model_save_path))
        self.corrnet.eval()

    def compute_text_embedding(self, query: str, embedding_dim=512):
        query_embedding = None
        tf_cleaned_input = text_standardization(query, self.preprocessor)

        v = np.array([0. for i in range(embedding_dim)])
        l = 0
        for word in (tf_cleaned_input.numpy()).decode('utf-8').split():
            word = '[UNK]' if word not in self.w2v_vocabs.keys() else word
            v += self.w2v_weights[self.w2v_vocabs[word]]
            l += 1
        query_embedding = v / l

        return query_embedding

    def compute_image_embedding(self, img, embedding_dim=512):
        return self.img2vec.get_vec(img)

    def image_search(self, query: str, image_name_set, top_num):
        query_embedding = self.compute_text_embedding(query=query)
        img_array = np.zeros((len(image_name_set), 512))
        for i in range(len(image_name_set)):
            img = Image.open(image_name_set[i]).convert('RGB')
            img_array[i] = self.img2vec.get_vec(img)

        txt_array = np.zeros((len(image_name_set), 512))
        for j in range(len(image_name_set)):
            txt_array[j] = query_embedding

        predictions = list(
            predict(self.corrnet, torch.from_numpy(txt_array.astype(np.float32)), torch.from_numpy(img_array.astype(np.float32)))[1])

        predictions_dict = dict(zip(image_name_set, predictions))
        predictions_dict = dict(sorted(predictions_dict.items(), key=lambda item: item[1]))
        if len(list(predictions_dict.keys())) > top_num:
            return list(predictions_dict.keys())[0:top_num]
        return predictions_dict.keys()

baseline = ImageSearchDemo()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

# STL10 dataset

In [None]:
from torchvision.datasets import STL10

stl10_dataset = STL10(root='STL10-dataset', download=True, folds=1)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to STL10-dataset/stl10_binary.tar.gz


  0%|          | 0/2640397119 [00:00<?, ?it/s]

Extracting STL10-dataset/stl10_binary.tar.gz to STL10-dataset


In [None]:
label_to_text =  {
      0: 'این عکس یک هواپیما است.',
      1: 'این عکس یک پرنده است.',
      2: 'این عکس یک ماشین است.',
      3: 'این عکس یک گربه است.',
      4: 'این عکس یک آهو است.',
      5: 'این عکس یک سگ است.',
      6: 'این عکس یک اسب است.',
      7: 'این عکس یک میمون است.',
      8: 'این عکس یک کشتی است.',
      9: 'این عکس یک کامیون است.'
}


### Text Embedding

In [None]:
! pip install -qU tqdm

In [None]:
from tqdm import tqdm

text_embeddings = []
for text in tqdm(label_to_text.values()):
    text_embeddings.append(baseline.compute_text_embedding(text).tolist())

100%|██████████| 10/10 [00:00<00:00, 19.67it/s]


### Image Embedding

In [None]:
len(stl10_dataset)

1000

In [None]:
image_embeddings = []
for image, _ in tqdm(stl10_dataset):
    image_embeddings.append(baseline.compute_image_embedding(image).tolist())

100%|██████████| 1000/1000 [00:13<00:00, 74.79it/s]


## Pairwise Cosine Similarity

In [None]:
print(len(image_embeddings))
print(len(text_embeddings))

1000
10


In [None]:
import numpy as np
from numpy import dot
from numpy.linalg import norm

def cosine_similarity(emb1, emb2):
    cos_sim = dot(emb1, emb2) / (norm(emb1) * norm(emb2))
    return cos_sim

In [None]:
cosine_matrix = []
for image_embedding in tqdm(image_embeddings):
    cosine_matrix.append([cosine_similarity(image_embedding, text_embedding) for text_embedding in text_embeddings ])

100%|██████████| 1000/1000 [00:01<00:00, 588.92it/s]


In [None]:
predicted_labels = []
true_labels = []


for i in range(len(cosine_matrix)):
  predicted_labels.append(cosine_matrix[i].index(max(cosine_matrix[i])))
  true_labels.append(stl10_dataset[i][1])

In [None]:
from sklearn.metrics import classification_report


report = classification_report(true_labels, predicted_labels)
print(report)

              precision    recall  f1-score   support

           0       0.22      0.31      0.26       100
           1       0.50      0.01      0.02       100
           2       0.00      0.00      0.00       100
           3       0.00      0.00      0.00       100
           4       0.00      0.00      0.00       100
           5       0.10      0.21      0.13       100
           6       0.00      0.00      0.00       100
           7       0.08      0.52      0.15       100
           8       0.00      0.00      0.00       100
           9       0.00      0.00      0.00       100

    accuracy                           0.10      1000
   macro avg       0.09      0.11      0.06      1000
weighted avg       0.09      0.10      0.06      1000



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# OxfordIIITPet

In [None]:
from torchvision.datasets import OxfordIIITPet

In [None]:
pet_dataset = OxfordIIITPet(root='pet-dataset', download=True)

Downloading https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz to pet-dataset/oxford-iiit-pet/images.tar.gz


  0%|          | 0/791918971 [00:00<?, ?it/s]

Extracting pet-dataset/oxford-iiit-pet/images.tar.gz to pet-dataset/oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/~vgg/data/pets/annotations.tar.gz to pet-dataset/oxford-iiit-pet/annotations.tar.gz


  0%|          | 0/19173078 [00:00<?, ?it/s]

Extracting pet-dataset/oxford-iiit-pet/annotations.tar.gz to pet-dataset/oxford-iiit-pet


In [None]:
dog_text = 'این عکس یک سگ، یک نوع حیوان خانگی است.'
cat_text = 'این عکس یک گربه، یک نوع حیوان خانگی است.'


pets_text = [dog_text, cat_text]


pet_label_transform =  {
    0: 1,
    1: 0,
    2: 0,
    3: 0,
    4: 0,
    5: 1,
    6: 1,
    7: 1,
    8: 0,
    9: 1,
    10: 0,
    11: 1,
    12: 0,
    13: 0,
    14: 0,
    15: 0,
    16: 0,
    17: 0,
    18: 0,
    19: 0,
    20: 1,
    21: 0,
    22: 0,
    23: 1,
    24: 0,
    25: 0,
    26: 1,
    27: 1,
    28: 0,
    29: 0,
    30: 0,
    31: 0,
    32: 1,
    33: 1,
    34: 0,
    35: 0,
    36: 0,
}

# Text Embedding

In [None]:
from tqdm import tqdm

text_embeddings = []
for text in tqdm(pets_text):
    text_embeddings.append(baseline.compute_text_embedding(text).tolist())

100%|██████████| 2/2 [00:00<00:00, 865.52it/s]


# Image Embedding

In [None]:
image_embeddings = []
for image, _ in tqdm(pet_dataset):
    image_embeddings.append(baseline.compute_image_embedding(image).tolist())

100%|██████████| 3680/3680 [00:50<00:00, 73.21it/s]


# Pairwise Cosine Similarity

In [None]:
print(len(image_embeddings))
print(len(text_embeddings))

3680
2


In [None]:
import numpy as np
from numpy import dot
from numpy.linalg import norm

def cosine_similarity(emb1, emb2):
    cos_sim = dot(emb1, emb2) / (norm(emb1) * norm(emb2))
    return cos_sim

In [None]:
cosine_matrix = []
for image_embedding in tqdm(image_embeddings):
    cosine_matrix.append([cosine_similarity(image_embedding, text_embedding) for text_embedding in text_embeddings ])

100%|██████████| 3680/3680 [00:01<00:00, 2286.43it/s]


In [None]:
predicted_labels = []
true_labels = []

for i in range(len(cosine_matrix)):
  predicted_labels.append(cosine_matrix[i].index(max(cosine_matrix[i])))
  true_labels.append(pet_label_transform[pet_dataset[i][1]])


## Evaluation

In [None]:
from sklearn.metrics import classification_report


report = classification_report(true_labels, predicted_labels)
print(report)

              precision    recall  f1-score   support

           0       0.68      1.00      0.81      2492
           1       0.00      0.00      0.00      1188

    accuracy                           0.68      3680
   macro avg       0.34      0.50      0.40      3680
weighted avg       0.46      0.68      0.55      3680



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
