In [15]:
from tair import Tair 
from tair.tairvector import DistanceMetric
from typing import List
import torch
from PIL import Image
import pylab
from matplotlib import pyplot as plt
import os
import cn_clip.clip as clip
from cn_clip.clip import available_models

In [18]:
model, preprocess = clip.load_from_name("RN50", device="cpu", download_root="./")
#model, preprocess = clip.load_from_name("RN50", device="cuda", download_root="./")
model.eval()

Loading vision model config from /root/miniconda3/lib/python3.10/site-packages/cn_clip/clip/model_configs/RN50.json
Loading text model config from /root/miniconda3/lib/python3.10/site-packages/cn_clip/clip/model_configs/RBT3-chinese.json
Model info {'embed_dim': 1024, 'image_resolution': 224, 'vision_layers': [3, 4, 6, 3], 'vision_width': 64, 'vision_patch_size': None, 'vocab_size': 21128, 'text_attention_probs_dropout_prob': 0.1, 'text_hidden_act': 'gelu', 'text_hidden_dropout_prob': 0.1, 'text_hidden_size': 768, 'text_initializer_range': 0.02, 'text_intermediate_size': 3072, 'text_max_position_embeddings': 512, 'text_num_attention_heads': 12, 'text_num_hidden_layers': 3, 'text_type_vocab_size': 2}


CLIP(
  (visual): ModifiedResNet(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (relu): ReLU(inplace=True)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn

In [16]:
def get_tair():
    """
    该方法用于连接Tair实例。
    * host：Tair实例连接地址。
    * port：Tair实例的端口号，默认为6379。
    * username：Tair实例的账户，可不填，使用默认账号登录。
    * password：Tair实例的密码。
    """
    client = Tair(
        host="120.27.213.45",
        port=6380,
    )
    return client
def create_index():
    # 创建存储图片embedding的索引
    ret = tair.tvs_get_index("index_images")
    if ret is None:
        tair.tvs_create_index("index_images", 1024, distance_type="IP",
                              index_type="HNSW")
    # 创建存储文字embedding的索引
    ret = tair.tvs_get_index("index_texts")
    if ret is None:
        tair.tvs_create_index("index_texts", 1024, distance_type="IP",
                              index_type="HNSW")
tair = get_tair()
# 分别创建存储图片和文本embedding的索引
create_index()

In [None]:
def extract_image_features(img_name):
    image_data = Image.open(img_name).convert("RGB")
    infer_data = preprocess(image_data)
#     infer_data = infer_data.unsqueeze(0).to("cuda")
    infer_data = infer_data.unsqueeze(0)
    with torch.no_grad():
        image_features = model.encode_image(infer_data)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()[0]
def insert_images(image_dir):
    file_names = [f for f in os.listdir(image_dir) if (f.endswith('.jpg') or f.endswith('.jpeg'))]
    for file_name in file_names:
        image_feature = extract_image_features(image_dir + "/" + file_name)
        tair.tvs_hset("index_images", image_dir + "/" + file_name, image_feature)
insert_images("/root/caoduanxin/images")

In [None]:
def extract_text_features(text):
    #text_data = clip.tokenize([text]).to("cuda")
    text_data = clip.tokenize([text])
    with torch.no_grad():
        text_features = model.encode_text(text_data)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy()[0]  # [1, 1024]
def query_images_by_text(text, topK):
    text_feature = extract_text_features(text)
    result = tair.tvs_knnsearch("index_images", topK, text_feature)
    for k, s in result:
        print(f'key : {k}, distance : {s}')
        img = Image.open(k.decode('utf-8'))
        plt.imshow(img)
        pylab.show()
query_images_by_text("奔跑的狗", 3)

In [None]:
def upsert_text(text):
    text_features = extract_text_features(text)
    tair.tvs_hset("index_texts", text, text_features)


def query_texts_by_image(image_path, topK=3):
    image_feature = extract_image_features(image_path)
    result = tair.tvs_knnsearch("index_texts", topK, image_feature)
    for k, s in result:
        print(f'text : {str(k, encoding = "utf-8")}, distance : {s}')
# 将文本"狗"、"白色的狗"、"奔跑的白色的狗"插入到index_texts索引中
upsert_text("狗")
upsert_text("白色的狗")
upsert_text("奔跑的白色的狗")
# 以图搜文，指定图片路径，找到比较符合文本描述的文本
query_texts_by_image("/root/caoduanxin/images/boxer_18.jpg",3)