In [1]:
import os
import re
from collections import Counter

import numpy as np
import pandas as pd
import torch
from elasticsearch import Elasticsearch
from img2vec_pytorch import Img2Vec
from PIL import Image

In [2]:
# Index name that central() passes to getSimilarity() as `index_name`.
index = 'plant_index'
# NOTE(review): not referenced in any of the visible cells — presumably the folder of the
# images that were indexed; confirm before removing.
base_image_path = "data/imgs/train_images"

In [3]:
# Embedding model shared by generate_embedding(). The GPU is used only when one is
# actually available, so the notebook also runs on CPU-only machines.
img2vec = Img2Vec(cuda=torch.cuda.is_available())



In [4]:
# Client used by getSimilarity() for every search request; it targets http://localhost:9200.
elastic_client = Elasticsearch(hosts=['http://localhost:9200'])

In [5]:
main_path = "data/imgs/test"
def get_files(path=None):
    """Collect the paths of all image files found under a directory tree.

    Args:
        path: Directory to walk recursively. Defaults to the module-level
            ``main_path`` when omitted.

    Returns:
        A sorted list of file paths (each one ``os.path.join(root, file)``)
        whose name ends in ``.jpg``, ``.png`` or ``.jpeg``. The extension is
        matched case-insensitively so files such as ``IMG.JPG`` are not
        silently skipped. A missing or empty directory yields an empty list.
    """
    search_root = main_path if path is None else path
    image_extensions = (".jpg", ".png", ".jpeg")
    images_files = [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(search_root)
        for file in files
        if file.lower().endswith(image_extensions)
    ]
    # os.walk gives no ordering guarantee; sort so repeated runs visit files identically.
    return sorted(images_files)
                

In [6]:
def generate_embedding(filepath):
    """Compute the embedding of one image file as a plain Python list.

    Args:
        filepath: Path of the image to open with PIL.

    Returns:
        The vector produced by ``img2vec.get_vec`` converted with ``.tolist()``.
    """
    # The context manager closes the file handle once the pixels have been read.
    with Image.open(filepath) as img:
        # Normalise to three channels: PNG inputs may be RGBA, palette or greyscale.
        # NOTE(review): assumes the embedding model expects RGB input — confirm.
        vec = img2vec.get_vec(img.convert("RGB"))
    return vec.tolist()


In [7]:
def getSimilarity(vector:list, embedding_field:str, index_name:str, size:int, k:int, candidate:int):
    """Run a kNN search against Elasticsearch and return the raw response.

    Args:
        vector: Query embedding sent as ``query_vector``.
        embedding_field: Name of the indexed vector field to search.
        index_name: Index to query.
        size: Value sent as the request's ``size``.
        k: Value sent as the kNN ``k``.
        candidate: Value sent as the kNN ``num_candidates``.

    Returns:
        The response of ``elastic_client.search``; each hit carries the
        requested ``label`` field under ``hit['fields']``.
    """
    result = elastic_client.search(
        index=index_name,
        body={
            "size": size,
            "knn": {
                "field": embedding_field,
                "query_vector": vector,
                "k": k,
                "num_candidates": candidate
            },
            "fields": ["label"],
            # Must be the boolean False: the string "false" is read as a field-name
            # pattern instead of switching off _source retrieval.
            "_source": False
        }
    )
    return result

In [8]:
# The function below only strips the directory parts from a file path string so that
# the bare file name can be used to look up the file's class in the dataframe.
# Label table: get_label() matches on its 'image' column and reads its 'labels' column.
df = pd.read_csv("data/imgs/labels.csv")
def get_label(name, labels_df=None):
    """Return the label recorded for an image file.

    The directory part and the extension of ``name`` are dropped, ``.jpg`` is
    appended, and the result is looked up in the ``image`` column of the label
    table.

    Args:
        name: File name or path of the image. Both ``/`` and ``\\`` are
            accepted as separators, and dots in directory names are ignored.
        labels_df: Optional DataFrame with ``image`` and ``labels`` columns.
            Defaults to the module-level ``df``.

    Returns:
        The ``labels`` value of the first row whose ``image`` equals the
        derived file name.

    Raises:
        IndexError: If no row matches the derived file name.
    """
    table = df if labels_df is None else labels_df
    # Split on both separator styles so Windows and POSIX paths resolve alike.
    base_name = re.split(r'[\\/]', name)[-1]
    # Only the final extension is dropped; earlier dots belong to the file name.
    name_hash = os.path.splitext(base_name)[0]
    name_for_es = name_hash + ".jpg"
    matches = table.loc[table['image'] == name_for_es, 'labels'].values
    if len(matches) == 0:
        raise IndexError("no label found for image {!r}".format(name_for_es))
    return matches[0]

In [9]:
def frequency_histogram(subclasses:list, k:int):
    """Return the k most frequent values of a list together with their counts.

    Args:
        subclasses: Values to count; they must be hashable and mutually
            orderable, because ties are broken by comparing the values.
        k: Maximum number of entries to return. Zero or a negative number
            yields an empty dict.

    Returns:
        A dict mapping value -> occurrence count, inserted from most to least
        frequent. Values with equal counts appear in descending value order.
    """
    # Single counting pass instead of calling list.count once per distinct value.
    counts = Counter(subclasses)
    # Sorting (count, value) pairs in reverse orders by count first, then by value.
    ranked = sorted(((count, label) for label, count in counts.items()), reverse=True)
    return {label: count for count, label in ranked[:max(k, 0)]}

In [10]:
def process_result(accuracy_dict, k, n, type):
    """Add one to the counter stored under the "<k>-<n>-<type>" key.

    A key seen for the first time starts at 1. ``accuracy_dict`` is updated in
    place and nothing is returned.
    """
    counter_key = "{}-{}-{}".format(k, n, type)
    accuracy_dict[counter_key] = accuracy_dict.get(counter_key, 0) + 1

In [11]:
def get_process_result(accuracy_dict, k, n, type):
    """Read the counter stored under the "<k>-<n>-<type>" key.

    Returns the stored value, or 0 when the key has never been recorded.
    ``accuracy_dict`` is not modified.
    """
    counter_key = "{}-{}-{}".format(k, n, type)
    return accuracy_dict.get(counter_key, 0)

In [12]:
def print_process_result(accuracy_dict, k_list, n_list):
    """Print the positive/negative counts and the accuracy of every (k, n) pair.

    Args:
        accuracy_dict: Counters keyed "<k>-<n>-positive" / "<k>-<n>-negative",
            read through ``get_process_result``.
        k_list: k values to report, in print order.
        n_list: n values to report for each k.

    Accuracy is positive / (positive + negative). A pair with no recorded
    outcome is reported as ``nan`` rather than raising ZeroDivisionError.
    """
    for k in k_list:
        for n in n_list:
            positive = get_process_result(accuracy_dict, k, n, 'positive')
            negative = get_process_result(accuracy_dict, k, n, 'negative')
            total = positive + negative
            # Nothing recorded for this pair (e.g. no input files): avoid dividing by zero.
            accuracy = positive / total if total else float('nan')
            print("k={} - n={} - Positive: {} - Negative: {} - " 
                "Accuracy: {} ".format(k,n,positive,negative,accuracy))

In [13]:
def transform_process_result(accuracy_dict, k_list, n_list):
    """Build the accuracy matrix for every (k, n) pair.

    Args:
        accuracy_dict: Counters keyed "<k>-<n>-positive" / "<k>-<n>-negative",
            read through ``get_process_result``.
        k_list: k values; one matrix row per entry, in order.
        n_list: n values; one matrix column per entry, in order.

    Returns:
        A ``len(k_list) x len(n_list)`` float array whose cell [i][j] holds
        positive / (positive + negative) for ``k_list[i]`` and ``n_list[j]``.
        Cells with no recorded outcome stay ``nan`` instead of raising
        ZeroDivisionError.
    """
    matrix = np.full((len(k_list), len(n_list)), np.nan)
    for i, k in enumerate(k_list):
        for j, n in enumerate(n_list):
            positive = get_process_result(accuracy_dict, k, n, 'positive')
            negative = get_process_result(accuracy_dict, k, n, 'negative')
            total = positive + negative
            if total:
                matrix[i][j] = positive / total
    return matrix

In [14]:
def central():
    """Evaluate label retrieval accuracy over every image returned by get_files().

    For each image the function looks up its expected label, embeds the image,
    retrieves ``max_n`` hits from Elasticsearch, and then, for every (k, n)
    pair, checks whether the expected label is among the k most frequent
    labels of the first n hits. Each check is recorded as "positive" or
    "negative"; the per-pair counts, accuracies and the accuracy matrix are
    printed at the end. Nothing is returned.
    """
    k_list = [1,2,3,4,5,6,7,8,9,10]
    n_list = [1,2,3,4,5,10,25,50,75,100]

    max_n = 100
    candidate = 100

    files = get_files()

    # For a quick smoke test, subsample the list here, e.g. files = files[::10].

    accuracy_dict = {}

    print("Starting the process")

    for query_id, img_file in enumerate(files, start=1):
        # img_file already contains the directory part.
        labels = get_label(img_file)

        vec = generate_embedding(img_file)

        result = getSimilarity(vec, "embedding", index, max_n, max_n, candidate)

        hits = result['hits']['hits']
        hit_list = []
        for hit in hits:
            try:
                hit_list.append(hit['fields']['label'])
            except KeyError:
                # Hit without a returned 'label' field: skip it, but keep the run going.
                # Catching KeyError only keeps the loop interruptible with Ctrl+C.
                print("Error")

        print("Query id: {} - Label: {} - Hits: {}".format(query_id, labels, len(hits)))

        # The flattened labels of the first n usable hits depend only on n,
        # so build them once per query instead of once per (k, n) pair.
        labels_by_n = {
            n: [label for sub in hit_list[:n] for label in sub]
            for n in n_list
        }

        # n is how many Elasticsearch results are considered and k is how many of the
        # most frequent classes are kept from their histogram: start with a single
        # class over 1..100 results, then widen the number of classes.
        for k in k_list:
            for n in n_list:
                # The histogram maps each class to the number of times it appears.
                histogram = frequency_histogram(labels_by_n[n], k)

                # Each query has a single label, so one membership test is enough.
                outcome = "positive" if labels in histogram else "negative"
                process_result(accuracy_dict, k, n, outcome)

    print("Finished the process")
    print_process_result(accuracy_dict, k_list, n_list)
    matrix = transform_process_result(accuracy_dict, k_list, n_list)
    print("Accuracy Matrix:")
    print(matrix)
    

In [15]:
# Run the full evaluation; it prints per-query progress and then the accuracy summary.
central()

Starting the process


  result = elastic_client.search(


Query id: 1 - Label: complex - Hits: 100
Query id: 2 - Label: complex - Hits: 100
Query id: 3 - Label: complex - Hits: 100
Query id: 4 - Label: complex - Hits: 100
Query id: 5 - Label: complex - Hits: 100
Query id: 6 - Label: complex - Hits: 100
Query id: 7 - Label: complex - Hits: 100
Query id: 8 - Label: complex - Hits: 100
Query id: 9 - Label: complex - Hits: 100
Query id: 10 - Label: complex - Hits: 100
Query id: 11 - Label: complex - Hits: 100
Query id: 12 - Label: complex - Hits: 100
Query id: 13 - Label: complex - Hits: 100
Query id: 14 - Label: complex - Hits: 100
Query id: 15 - Label: complex - Hits: 100
Query id: 16 - Label: complex - Hits: 100
Query id: 17 - Label: complex - Hits: 100
Query id: 18 - Label: complex - Hits: 100
Query id: 19 - Label: complex - Hits: 100
Query id: 20 - Label: complex - Hits: 100
Query id: 21 - Label: complex - Hits: 100
Query id: 22 - Label: complex - Hits: 100
Query id: 23 - Label: complex - Hits: 100
Query id: 24 - Label: complex - Hits: 100
Q