In [None]:
import pytorch_lightning as pl
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import torch.nn as nn

class AestheticsMLP(pl.LightningModule):
    """Aesthetics predictor: an MLP regression head applied to an embedding.

    The network is a stack of ``nn.Linear`` layers separated only by dropout;
    the ReLU activations are disabled. The module positions inside
    ``self.layers`` determine the ``state_dict`` keys (``layers.0.weight``,
    ``layers.2.weight``, ...), so inserting or removing modules would renumber
    them and stop weights saved with this layout from loading.

    Args:
        input_size: Size of the last dimension of the input embedding.
        xcol: Key of the batch mapping that holds the input embeddings.
        ycol: Key of the batch mapping that holds the regression targets.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        # Keep this exact module order: indices 0, 2, 4, 6, 7 are the Linear
        # layers and 1, 3, 5 the Dropout layers (no activations in between).
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """Return the predicted score(s) for ``x``, with a trailing dimension of 1."""
        return self.layers(x)

    def _mse_step(self, batch):
        """Compute the MSE between the predictions and targets of one batch."""
        x = batch[self.xcol]
        # Targets are reshaped to a column so they line up with the (N, 1) output.
        y = batch[self.ycol].reshape(-1, 1)
        y_hat = self.layers(x)
        # Accessed through ``nn`` because no functional alias (``F``) is
        # imported in this file.
        return nn.functional.mse_loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Return the MSE loss for a training batch."""
        return self._mse_step(batch)

    def validation_step(self, batch, batch_idx):
        """Return the MSE loss for a validation batch."""
        return self._mse_step(batch)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        # Imported locally so this class does not depend on ``torch`` being
        # imported by some other notebook cell first.
        import torch

        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [None]:
import open_clip
import torch
from transformers import BertModel, BertTokenizer
class Score:
    """Image/text similarity and aesthetics scoring built on CLIP features.

    Holds a BERT text encoder and tokenizer loaded from ``clip_model_path``, an
    open_clip ``ViT-L-14`` image model with its preprocessing transform, and an
    ``AestheticsMLP`` restored from ``aesthetics_model_path``. Every model is
    put in eval mode and moved to the current CUDA device.

    Args:
        clip_model_path: Name or path passed to ``from_pretrained`` for the
            text encoder and tokenizer.
        aesthetics_model_path: Path of the saved ``AestheticsMLP`` state dict.
    """

    def __init__(
        self,
        clip_model_path="IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese",
        aesthetics_model_path="/cognitive_comp/chenweifeng/project/dl_scripts/text-image/data_filter_system/ava+logos-l14-linearMSE.pth",
    ):
        self.text_encoder, self.text_tokenizer, self.clip_model, self.processor = self.load_clip_model(clip_model_path)
        self.aesthetics_model = self.init_aesthetics_model(aesthetics_model_path)

    def load_clip_model(self, model_path="IDEA-CCNL/Taiyi-CLIP-RoBERTa-102M-ViT-L-Chinese"):
        """Load the text encoder/tokenizer and the image model.

        Args:
            model_path: Name or path of the pretrained text model.

        Returns:
            Tuple ``(text_encoder, text_tokenizer, clip_model, processor)``;
            ``processor`` is the third value returned by
            ``open_clip.create_model_and_transforms`` and is applied to each
            image before encoding.
        """
        text_encoder = BertModel.from_pretrained(model_path).eval().cuda()
        text_tokenizer = BertTokenizer.from_pretrained(model_path)
        clip_model, _, processor = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
        clip_model = clip_model.eval().cuda()
        return text_encoder, text_tokenizer, clip_model, processor

    def init_aesthetics_model(self, aesthetics_model_path):
        """Build the aesthetics MLP and restore its weights.

        Args:
            aesthetics_model_path: Path of the saved state dict.

        Returns:
            The ``AestheticsMLP`` in eval mode on the current CUDA device.
        """
        # 768 is the MLP input size; presumably the ViT-L-14 embedding width
        # produced by get_image_feature -- confirm if the image model changes.
        aesthetics_model = AestheticsMLP(768)
        # Deserialize onto the CPU first so loading does not depend on the GPU
        # index the checkpoint was saved from; .cuda() below moves the weights.
        state_dict = torch.load(aesthetics_model_path, map_location="cpu")
        aesthetics_model.load_state_dict(state_dict)
        aesthetics_model.eval().cuda()
        print("aesthetics model loaded")
        return aesthetics_model

    def get_image_feature(self, images):
        """Return L2-normalized image features.

        Args:
            images: Iterable of images accepted by ``self.processor``.

        Returns:
            2-D CUDA tensor with one unit-norm row per image.
        """
        images = torch.stack([self.processor(image) for image in images]).cuda()
        with torch.no_grad():
            image_features = self.clip_model.encode_image(images)
            image_features /= image_features.norm(dim=1, keepdim=True)
        return image_features

    def get_aesthetics_score(self, features):
        """Return the aesthetics score of each feature row.

        Args:
            features: Image features as returned by ``get_image_feature``
                (compute those first, then pass them here).

        Returns:
            1-D NumPy array holding the first output column of the MLP.
        """
        with torch.no_grad():
            scores = self.aesthetics_model(features)
            scores = scores[:, 0].detach().cpu().numpy()
        return scores

    def get_text_feature(self, text):
        """Return L2-normalized text features.

        Args:
            text: A string or list of strings to tokenize (padded per batch).

        Returns:
            2-D CUDA tensor with one unit-norm row per text.
        """
        # NOTE(review): only input_ids are forwarded, so no attention mask is
        # applied to padded positions -- confirm this matches how the
        # checkpoint is meant to be used.
        text = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].cuda()
        with torch.no_grad():
            # Index 1 of the BertModel output is the pooled output.
            text_features = self.text_encoder(text)[1]
            text_features /= text_features.norm(dim=1, keepdim=True)
        return text_features

    def calculate_score(self, features1, features2):
        """Return the pairwise dot-product matrix of two feature sets.

        The inputs may be image+text, text+text or image+image features. With
        the unit-norm features returned by the getters above, each entry is a
        cosine similarity.

        Args:
            features1: 2-D tensor with one feature per row.
            features2: 2-D tensor with one feature per row.

        Returns:
            Tensor whose entry ``[i, j]`` scores row ``i`` of ``features1``
            against row ``j`` of ``features2``.
        """
        score_each_pair = features1 @ features2.t()
        return score_each_pair

In [None]:
# Build the scorer once: constructing it loads the text encoder, the image
# model and the aesthetics MLP and moves them to the GPU.
demo = Score()

In [None]:
from PIL import Image

# Demo inputs: two images on disk and a list of candidate captions.
image_path = '/cognitive_comp/chenweifeng/project/dl_scripts/text-image/data_filter_system/demo_images/1.jpg'
image_path2 = '/cognitive_comp/chenweifeng/project/dl_scripts/text-image/data_filter_system/demo_images/mengna.jpg'
image_demo = [Image.open(path).convert('RGB') for path in (image_path, image_path2)]
# A single-element list (i.e. one query) works here as well.
text_demo = ['一副很美的画', '港口小船', '大海', '沙漠', '蒙娜丽莎']

# Encode both modalities. Image features take a list of images and are
# typically worth storing in a database so they can answer later text queries;
# text features take a list of strings.
image_feature = demo.get_image_feature(image_demo)
text_feature = demo.get_text_feature(text_demo)
# To inspect the shapes: print(image_feature.shape, text_feature.shape)

# Similarity matrix: one row per image, one column per text.
similarity = demo.calculate_score(image_feature, text_feature)
print(similarity)

# Aesthetics score for each image.
aes_score = demo.get_aesthetics_score(image_feature)
print(aes_score)