In [5]:
import open_clip
import torch
from transformers import BertModel, BertTokenizer
class CLIP_Score:
    def __init__(
                    self, 
                    clip_model_path="IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese",
                ):
        self.text_encoder, self.text_tokenizer, self.clip_model, self.processor = self.load_clip_model(clip_model_path)
    
    def load_clip_model(self, model_path="IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese"):
        text_encoder = BertModel.from_pretrained(model_path).eval().cuda()
        text_tokenizer = BertTokenizer.from_pretrained(model_path)
        clip_model, _, processor = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
        clip_model = clip_model.eval().cuda()
        return text_encoder, text_tokenizer, clip_model, processor


    def get_image_feature(self, images):
        # 此处返回图像的特征向量
        images = torch.stack([self.processor(image) for image in images]).cuda()
        with torch.no_grad():
            image_features = self.clip_model.encode_image(images)
            image_features /= image_features.norm(dim=1, keepdim=True)
        return image_features
    
    def get_text_feature(self, text):
        # 此处返回文本的特征向量
        text = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].cuda()
        with torch.no_grad():
            text_features = self.text_encoder(text)[1]
            text_features /= text_features.norm(dim=1, keepdim=True)
        return text_features

    def calculate_score(self, features1, features2):
        # 此处2个特征向量的相似度，输入可以是 图片+文本、文本+文本、图片+图片。
        # 返回的是相似度矩阵
        score_each_pair =  features1 @ features2.t()
        return scores

In [6]:
demo = CLIP_Score()

In [10]:
from PIL import Image

image_path = '/cognitive_comp/chenweifeng/project/dl_scripts/text-image/data_filter/improved-aesthetic-predictor/1.jpg'
image_path2 = '/cognitive_comp/chenweifeng/project/dl_scripts/text-image/data_filter/Asthetics/mengna.jpg'
image_demo =  [Image.open(image_path).convert('RGB'), Image.open(image_path2).convert('RGB')]
text_demo = ['一副很美的画','港口小船', '大海', '沙漠', '蒙娜丽莎'] # 这里也可以只有一个文本，也就是query


image_feature = demo.get_image_feature(image_demo)  # 计算图片特征，传入图片列表，一般而言，可以在数据库保存这个东西，用于响应文本query
text_feature = demo.get_text_feature(text_demo) # 计算文本特征，传入文本列表
print(image_feature.shape, text_feature.shape)
print(image_feature @ text_feature.t())
similarity = demo.calculate_score(image_feature, text_feature)  # 计算相似度
print(similarity)

torch.Size([2, 768]) torch.Size([1, 768])
tensor([[-0.0492],
        [ 0.2003]], device='cuda:0')
[-0.04917662]
