In [None]:
import os

import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTModel


# Sanity check: embed one sample COCO image with a ViT backbone and show the
# shape of the resulting image embedding.
MODEL_NAME = "google/vit-base-patch16-224-in21k"
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# A timeout prevents the cell from hanging forever; raise_for_status stops an
# HTTP error page from being handed to PIL as if it were an image.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# convert("RGB") forces a full decode and guarantees 3 channels for the processor.
image = Image.open(response.raw).convert("RGB")

image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = ViTModel.from_pretrained(MODEL_NAME)
model.eval()  # inference only

inputs = image_processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# last_hidden_state is (batch, seq_len, hidden_size); position 0 is the [CLS]
# token ViT prepends, used throughout this notebook as the image embedding.
last_hidden_states = outputs.last_hidden_state[:, 0]
list(last_hidden_states.shape)

In [None]:
def extract_feature(img_dir, image_processor, model):
    """Compute one embedding per image file in ``img_dir``.

    Each file is opened with PIL, converted to RGB, run through
    ``image_processor`` and ``model``, and the hidden state at token position 0
    of ``last_hidden_state`` (the [CLS] token for ViT models) is kept.

    Args:
        img_dir: directory whose regular files are treated as images.
        image_processor: callable turning a PIL image into model inputs when
            called with ``return_tensors="pt"`` (e.g. an ``AutoImageProcessor``).
        model: model whose output exposes ``last_hidden_state``.

    Returns:
        dict mapping each file name (not the full path) to its feature tensor,
        shape (1, hidden_size) since images are processed one at a time.
        Keys are inserted in sorted file-name order.
    """
    feats = {}
    # os.listdir order is arbitrary; sorting makes runs reproducible.
    for f in sorted(os.listdir(img_dir)):
        f_path = os.path.join(img_dir, f)
        # Sub-directories (and other non-file entries) cannot be opened as images.
        if not os.path.isfile(f_path):
            continue
        print("extracting feats of {}".format(f_path))
        # The with-block closes the file handle; convert("RGB") forces a full
        # load and normalizes grayscale/RGBA/palette images to 3 channels.
        with Image.open(f_path) as img:
            image = img.convert("RGB")
        inputs = image_processor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        feats[f] = outputs.last_hidden_state[:, 0]
        print("success extracting feats of {}".format(f_path))
    return feats

def get_result(ori_feats, back_feats, topn=3):
    """Rank candidate images by cosine similarity to each query image.

    Args:
        ori_feats: dict mapping query filename -> feature tensor, as returned
            by ``extract_feature`` (assumed shape (1, hidden_size) — one
            feature row per image; TODO confirm if batching is ever added).
        back_feats: dict mapping candidate filename -> feature tensor with the
            same layout.
        topn: number of best-scoring candidates kept per query (default 3).

    Returns:
        dict mapping each query filename to a list of at most ``topn``
        ``(candidate_filename, similarity)`` tuples, highest similarity first.
        Similarities are plain Python floats.
    """
    result = {}
    for query_name, query_feat in ori_feats.items():
        # .item() turns the 1-element similarity tensor into a float so the
        # results are easy to print, sort and serialize.
        scores = [
            (cand_name, torch.nn.functional.cosine_similarity(query_feat, cand_feat).item())
            for cand_name, cand_feat in back_feats.items()
        ]
        scores.sort(key=lambda item: item[1], reverse=True)
        result[query_name] = scores[:topn]
    return result

def combine_result(back_result, profile_result):
    """Attach each query's profile matches to its back-view matches, in place.

    For every key of ``back_result``, the whole list ``profile_result[key]`` is
    appended as ONE extra element of ``back_result[key]``.

    NOTE(review): the appended element is itself a list, so the result mixes
    ``(name, score)`` tuples with a nested list. If a flat list was intended,
    append ``profile_result[key][0]`` or use ``extend`` instead.

    Args:
        back_result: dict of query filename -> list; mutated in place.
        profile_result: dict that must contain every key of ``back_result``.

    Returns:
        The same ``back_result`` object, after mutation.

    Raises:
        KeyError: if ``profile_result`` lacks a key present in ``back_result``.
    """
    for query, matches in back_result.items():
        matches.append(profile_result[query])
    return back_result


def show_result(back_result, profile_result,  ori_path, back_path, profile_path):
    """Plot every query image next to its retrieved matches, one row per query.

    Each row shows, left to right:
    - the query image (loaded from ``ori_path``);
    - every match in ``back_result[query]`` (loaded from ``back_path``);
    - in the last column, the first entry of ``profile_result[query]`` (loaded
      from ``profile_path``).

    Match images are titled with their similarity score.

    Args:
        back_result: dict mapping query filename -> list of
            ``(filename, score)`` matches, e.g. the output of ``get_result``.
        profile_result: dict mapping query filename -> list of
            ``(filename, score)`` matches; only the first entry is shown.
            Must contain every key of ``back_result``.
        ori_path: directory holding the query images.
        back_path: directory holding the images named in ``back_result``.
        profile_path: directory holding the images named in ``profile_result``.

    Raises:
        KeyError: if ``profile_result`` lacks a query present in ``back_result``.
    """
    if not back_result:
        # plt.subplots cannot build a grid with zero rows; nothing to show.
        return

    def _show_image(ax, img_path, title):
        # Draw one image file on `ax`; the with-block closes the file handle and
        # convert("RGB") keeps grayscale images from being drawn with a colormap.
        with Image.open(img_path) as img:
            ax.imshow(img.convert("RGB"))
        ax.set_title(title)
        ax.axis("off")

    row_num = len(back_result)
    # One column for the query, one per back-view match, one for the profile match.
    col_num = 2 + max(len(matches) for matches in back_result.values())
    # squeeze=False keeps `axs` 2-D even when there is a single query row.
    fig, axs = plt.subplots(row_num, col_num, figsize=(5 * col_num, 5 * row_num), squeeze=False)

    for row_axes, (query, matches) in zip(axs, back_result.items()):
        _show_image(row_axes[0], os.path.join(ori_path, query), query)
        for ax, (name, score) in zip(row_axes[1:], matches):
            _show_image(ax, os.path.join(back_path, name), "{} ({:.3f})".format(name, float(score)))
        # Blank the padding cells of rows with fewer matches than the widest row.
        for ax in row_axes[1 + len(matches):-1]:
            ax.axis("off")
        profile_name, profile_score = profile_result[query][0]
        _show_image(
            row_axes[-1],
            os.path.join(profile_path, profile_name),
            "{} ({:.3f})".format(profile_name, float(profile_score)),
        )

    fig.tight_layout()

In [None]:
# Directories: query (original) images, scene images shot from the front, and
# scene images shot from the side (profile).
ori_path = "./测试2/原图_small"
back_path = "./测试2/场景图/正面"
profile_path = "./测试2/场景图/侧面"
# Number of front-view matches shown per query.
TOPN = 3

ori_feats = extract_feature(ori_path, image_processor, model)
back_feats = extract_feature(back_path, image_processor, model)
profile_feats = extract_feature(profile_path, image_processor, model)

# Both result dicts are keyed by query (original) filename, which is how
# show_result looks them up; only the single best profile match is displayed.
back_res = get_result(ori_feats, back_feats, TOPN)
profile_res = get_result(ori_feats, profile_feats, 1)

show_result(back_res, profile_res, ori_path, back_path, profile_path)