In [2]:
!pip install annoy tqdm # -i https://opentuna.cn/pypi/web/simple
import time
import torch
import numpy as np
import sys
import os
import requests
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from io import BytesIO
import pandas as pd

from scipy.spatial import distance

import pickle
from annoy import AnnoyIndex
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
import tqdm
import sagemaker
print(sagemaker.__version__)

from sagemaker.pytorch import PyTorch, PyTorchPredictor
from sagemaker.pytorch.model import PyTorchModel

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p37/bin/python -m pip install --upgrade pip' command.[0m
2.77.1


In [None]:
def get_index_mae(model_name, seq_net, dim=768, n_trees=100):
    """Build, or load from cache, a Euclidean Annoy index of image embeddings.

    Every ``.jpg``/``.png`` file one directory level below the global
    ``base_dir`` (i.e. ``base_dir/<sub_dir>/<file>``) is read with OpenCV,
    embedded with ``seq_net.predict`` and added to the index. The index and
    the list of image paths are cached in the current working directory as
    ``index_advance_<model_name>.ann`` and ``keys_<model_name>.pkl``. If both
    files already exist, they are loaded instead of being rebuilt.

    Args:
        model_name: Suffix used for the cache file names.
        seq_net: Object exposing ``predict(img)`` that returns the embedding of
            one image (in this notebook, a SageMaker ``PyTorchPredictor``).
        dim: Embedding length used when building a new index. It is also the
            first vector length tried when loading a cached index (2048 is
            tried next).
        n_trees: Number of trees passed to ``AnnoyIndex.build``.

    Returns:
        ``(index, keys)``, where ``keys[i]`` is the image path stored under
        Annoy item id ``i``.

    Raises:
        RuntimeError: if no image under ``base_dir`` could be embedded.
        Exception: whatever ``AnnoyIndex.load`` raised, if no candidate vector
            length could load the cached index.
    """
    ann_filename = 'index_advance_' + model_name + '.ann'
    keys_filename = 'keys_' + model_name + '.pkl'

    if os.path.exists(ann_filename) and os.path.exists(keys_filename):
        # The vector length is not known when loading, so try the candidate
        # sizes in order and keep the first one that AnnoyIndex.load accepts.
        # NOTE(review): a wrong size is only detected if load() raises, so this
        # detection is best-effort.
        u = None
        load_error = None
        for candidate_dim in dict.fromkeys((dim, 2048)):
            index = AnnoyIndex(candidate_dim, 'euclidean')
            try:
                index.load(ann_filename)
            except Exception as e:  # was a bare `except:`; keep Ctrl-C working
                load_error = e
                continue
            u = index
            break
        if u is None:
            raise load_error
        # NOTE(review): pickle.load can execute arbitrary code; only load key
        # files that this notebook wrote itself.
        with open(keys_filename, 'rb') as fh:
            keys = pickle.load(fh)
        print('keys:', len(keys), keys[:5])
        return u, keys

    start = time.time()
    keys = []  # keys[i] is the image path stored under Annoy item id i
    u = AnnoyIndex(dim, 'euclidean')
    for sub_name in tqdm.tqdm(os.listdir(base_dir)):
        sub_dir = os.path.join(base_dir, sub_name)
        if not os.path.isdir(sub_dir):
            # A stray file next to the sub-directories would make listdir() fail.
            continue
        for name in os.listdir(sub_dir):
            if not name.endswith(('.jpg', '.png')):
                continue
            filename = os.path.join(sub_dir, name)
            try:
                img = cv2.imread(filename)
                if img is None:
                    # cv2.imread returns None instead of raising on unreadable files.
                    raise ValueError('cv2.imread could not read the image')
                # Flatten so that a (1, dim) or (dim, 1) response becomes the
                # flat vector Annoy expects.
                emb = np.ravel(seq_net.predict(img))
                # Ids stay contiguous and aligned with `keys`, because a
                # failed image is never appended.
                u.add_item(len(keys), emb)
            except Exception as e:
                print(filename, e)
                continue
            keys.append(filename)
    end = time.time()
    if not keys:
        # Previously a ZeroDivisionError below. An empty index is useless,
        # and caching it would hide the problem on later runs.
        raise RuntimeError('no images could be embedded under ' + repr(base_dir))
    print('get_embedding time:', (end - start) / len(keys), len(keys))

    u.build(n_trees)
    u.save(ann_filename)
    with open(keys_filename, 'wb') as fh:
        pickle.dump(keys, fh)

    return u, keys

def evaluate_mae(u, keys, model_name, seq_net, k=None):
    """Query the Annoy index with every query image and count label matches.

    Every ``jpg``/``png`` file one level below the global ``query_dir`` is
    embedded with ``seq_net.predict`` and looked up in ``u``. The prediction is
    the nearest indexed image whose path differs from the query, so a query
    that is also in the index does not match itself. An image's label is the
    name of its parent directory.

    Per-query results are written to ``result_<model_name>.xlsx``. The same
    rows, plus ``label``/``pred``/``correct`` columns, go to
    ``result_eval_<model_name>.xlsx``.

    Args:
        u: Annoy index, as returned by ``get_index_mae``.
        keys: Image paths aligned with the item ids of ``u``.
        model_name: Suffix used for the output spreadsheet names.
        seq_net: Object exposing ``predict(img)`` that returns one embedding.
        k: Number of neighbours to retrieve; defaults to the global ``K``.
            When the query images are also indexed, the query itself takes one
            of the ``k`` slots, so ``k`` should be at least 2.

    Returns:
        The number of queries whose predicted label equals their true label.
    """
    if k is None:
        k = K

    def _parent_dir_name(path):
        # Images live at <root>/<label>/<file>, so the label is the parent folder.
        return os.path.basename(os.path.dirname(path))

    query_paths = []
    for sub_name in os.listdir(query_dir):
        sub_dir = os.path.join(query_dir, sub_name)
        if not os.path.isdir(sub_dir):
            # A stray file next to the sub-directories would make listdir() fail.
            continue
        query_paths.extend(
            os.path.join(sub_dir, name)
            for name in os.listdir(sub_dir)
            if name.endswith(('jpg', 'png'))
        )

    records = []
    for p in tqdm.tqdm(query_paths):
        try:
            # Previously cv2.imread(filename): `filename` is undefined here, so
            # every query raised NameError and was silently skipped.
            img = cv2.imread(p)
            if img is None:
                # cv2.imread returns None instead of raising on unreadable files.
                raise ValueError('cv2.imread could not read the image')
            emb = np.ravel(seq_net.predict(img))
        except Exception as e:
            print(p, e)
            continue
        # Previously passed the undefined name `target_feature`.
        neighbours = u.get_nns_by_vector(emb, k)
        # get_nns_by_vector returns ids closest-first; take the first one that
        # is not the query itself. (The old loop kept the *last* such id, i.e.
        # the farthest of the k neighbours when k > 2.)
        match = next((keys[i] for i in neighbours if keys[i] != p), '')
        records.append({
            'filename': p,
            'min_sim_filename': match,
            'min_sim': 1,   # placeholder: no similarity score is computed here
            'sift_sim': 0,  # placeholder: always 0 here
        })

    # Explicit columns keep the schema even when no query succeeded.
    result = pd.DataFrame(records, columns=['filename', 'min_sim_filename', 'min_sim', 'sift_sim'])
    result.to_excel('result_' + model_name + '.xlsx')

    # Labels come from the parent folder instead of a fixed split('/')[5],
    # which only worked at one directory depth. assign() returns a new frame,
    # so `result` is not mutated.
    result_eval = result.assign(
        label=result['filename'].map(_parent_dir_name),
        pred=result['min_sim_filename'].map(_parent_dir_name),
    )
    result_eval['correct'] = result_eval['label'] == result_eval['pred']
    result_eval.to_excel('result_eval_' + model_name + '.xlsx')

    return result_eval['correct'].sum()

# 加载 endpoint (Load the deployed SageMaker endpoint)

In [5]:
# Connect to the already-deployed SageMaker endpoint. `model` is an alias that
# the later cells pass to get_index_mae / evaluate_mae as `seq_net`.
predictor = model = PyTorchPredictor(endpoint_name='pytorch-inference-2022-03-18-06-53-11-274')

# 生成随机投影森林 (Build the Annoy random-projection-forest index)

In [None]:
# Index and queries share one image tree; evaluate_mae skips each query's own entry.
base_dir = query_dir = 'data/haitian/haitian_recognition/644_final/images'

print('Model loaded.')
# Builds the index, or loads the cached .ann/.pkl files if they already exist.
u, keys = get_index_mae('mae_vit_base_pretrain', model)

# 推理 + 搜索结果 + 评估 (Inference, nearest-neighbour search, and evaluation)

In [None]:
# Evaluation run. get_index_mae reuses the cached index_advance_/keys_ files
# when they exist, so the images are not re-embedded here.
model_name = 'mae_vit_base_pretrain'
K = 2  # neighbours fetched per query (other values tried: 3/6/20/40)
# Index and queries share one image tree; evaluate_mae skips each query's own entry.
base_dir = query_dir = 'data/haitian/haitian_recognition/644_final/images'

u, keys = get_index_mae(model_name, model)
correct_sum = evaluate_mae(u, keys, model_name, model)
print('correct_sum:', correct_sum)