In [57]:
# 为文件夹下的图片提取特征并建立索引
# 1.遍历图片
# 2.提取特征
# 3.建立索引
# 4.存储文件名与索引匹配关系
# 5.保存索引到磁盘

In [58]:
# 导入相关包
import torch
import torch.nn as nn
import timm
from torchvision import transforms as Transforms
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import faiss
import glob
import os
import tqdm
# 导入自定义特征提取类
from tools.feature_extract import SwinTransformer, Data_Processor


In [59]:
# 实例化
data_processor = Data_Processor(height=224, width=224)
model = SwinTransformer(num_features=512).cuda()

In [60]:
# model.eval()

In [61]:
# 加载权重
weight_path = 'weights/swin_base_patch4_window7_224.pth'
weight = torch.load(weight_path)
model.load_state_dict(weight['state_dict'], strict=True)

<All keys matched successfully>

In [62]:
# 提取特征函数
def getImgFeat(img_file):
    # PIL read image
    img = Image.open(img_file).convert("RGB")  # 读取图片，转换为RGB
    img = data_processor(img).cuda()  # 数据预处理
    feat = F.normalize(model(img), dim=1).cpu()  # 使用F.normalize对特征进行L2归一化
    return feat

In [63]:
# 遍历图片，提取特征
imgs_path = glob.glob("selected_imgs/*.jpg")
feats_list = []
names_list = []
for imgPath in tqdm.tqdm(imgs_path, desc="遍历图片"):
    with torch.no_grad():
        feat = getImgFeat(imgPath)
    name = imgPath.split(os.sep)[-1].split(".jpg")[0]
    feats_list.append(feat)
    names_list.append(name)
    
feats_list = torch.cat(feats_list, 0) # 将所有特征拼接起来
    # print(feat.shape)
    # print(name) 
    # break
# feats_list.shape

遍历图片: 100%|██████████| 18/18 [00:02<00:00,  6.56it/s]


In [64]:
# 将文件名列表存入numpy
names_list = np.array(names_list)
np.save("weights/names_list.npy", names_list)

In [65]:
# faiss数据库
# 创建索引
index = faiss.IndexFlatIP(512)
# feats_list转为np格式，再添加到faiss中
feats_list = np.array(feats_list)
index.add(feats_list)

In [66]:
# save index
faiss.write_index(index, 'weights/index_idols.index')

In [67]:
# read index
index = faiss.read_index('weights/index_idols.index')

In [68]:
# 简单测试一下
k = 4
D,I = index.search(feats_list[:4],k)
print(D)
print(I)

[[0.9999996  0.57087135 0.5184952  0.4959144 ]
 [1.0000002  0.7903343  0.7590165  0.68499094]
 [1.0000007  0.7808845  0.7530099  0.69316435]
 [0.99999946 0.70542264 0.68893516 0.6724826 ]]
[[ 0  6  5 13]
 [ 1 13  5  8]
 [ 2 13  8  5]
 [ 3 15  5  8]]


In [69]:
# 查询
query_img = imgs_path[0]
with torch.no_grad():
    feat_query = getImgFeat(query_img)
feat_query_np = feat_query.numpy()

In [70]:
feat_query_np.shape

(1, 512)

In [71]:
# 查询相似度在0.8以上的图片
threshold = 0.8
lims, D, I = index.range_search(feat_query_np, threshold) 
print(lims.shape, D.shape, I.shape)
print(lims)
print(I)
print(D)

(2,) (1,) (1,)
[0 1]
[0]
[0.91487265]


In [72]:
D,I = index.search(feat_query_np,k)
print(D)
print(I)

[[0.91487265 0.66761667 0.58117217 0.5723569 ]]
[[ 0  6  5 17]]
