In [21]:
import os
import torch
import pickle
from tqdm import *
import numpy as np
import pandas as pd
from PIL import Image
import cn_clip.clip as clip
from cn_clip.clip import load_from_name, available_models
from eval_recall import get_performance
from utils.config import get_config
from utils.utils import set_seed
from time import time

class Raw2Vector:
    def __init__(self, image_model, text_model, args):
        self.args = args
        print("Available models:", available_models())
        # Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = load_from_name("ViT-B-16", device=device, download_root='../BigDataSource/')
        self.model.eval()

    def image2tensor(self, image):
        image = image.unsqueeze(0).to(self.args.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features

    def text2tensor(self, text):
        # text = text.unsqueeze(0).to(self.args.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    def retrieve(self, query, database):
        # query = query.unsqueeze(0).to(self.args.device)
        query, database = query.to(self.args.device), database.to(self.args.device)
        logit_scale = self.model.logit_scale.exp().to(self.args.device)
        logits_per_query = logit_scale * query @ database.t()
        probs = logits_per_query.softmax(dim=-1).cpu().detach().numpy()
        return probs


In [22]:


def high_speed_retreive(database, query, model, k):
    if isinstance(database, torch.Tensor):
        database = database.cpu()
    if isinstance(query, torch.Tensor):
        query = query.cpu()
    # Group truth 自己检索自己，除了自己以外的排名，故 + 1
    from modules.models.Retrieve import LSH, L2Index
    d = 512  # 向量维度
    if model == 'L2':
        print('执行L2检索部分')
        t1 = time()
        l2index = L2Index(k=k, d=d)
        topk_distence, topk_indices = l2index.search_topk_embeds(database, query)
        t2 = time()
        print(f'L2: {t2 - t1 : .2f}s')
    elif model == 'KNN':
        print('执行KNN检索部分')
        probs = transfer.retrieve(database, query)
        t1 = time()
        def numpy_top_k_indices(matrix, k, axis=1):
            if axis == 1:  # 处理每行
                k = min(k, matrix.shape[1])
                indices = np.argsort(matrix, axis=1)[:, -k:][:, ::-1]
            elif axis == 0:  # 处理每列
                k = min(k, matrix.shape[0])
                indices = np.argsort(matrix, axis=0)[-k:, :][::-1, :]
            return indices
        topk_indices = numpy_top_k_indices(probs, k, 1)
        t2 = time()
        print(f'KNN: {t2 - t1 : .2f}s')
    elif model == 'LSH':
        print('执行LSH检索部分')
        t1 = time()
        lsh = LSH(k=k, d=d, nbits=2048)
        topk_distence, topk_indices = lsh.search_topk_embeds(database, query)
        t2 = time()
        print(f'LSH: {t2 - t1 : .2f}s')
    return topk_indices


In [23]:

def get_all_image():
    # 本地化保存
    try:
        with open('../BigDataSource/Teddy2024/附件3/image_features_问题2.pkl', 'rb') as f:
            all_image_features = pickle.load(f)
            # torch.load(f, map_location='cpu')
    except:
        # 获得附件1数据集代码
        all_image = os.listdir('../BigDataSource/Teddy2024/附件3/ImageData')

        # 获得预训练模型
        transfer = Raw2Vector('ViT-B-16', '1', args)

        # 准备用多线程代码迅速获得所有图像的张量
        from concurrent.futures import ThreadPoolExecutor, as_completed
        def function(inputs):
            image_name = inputs
            image_address = '../BigDataSource/Teddy2024/附件3/ImageData/'
            file_name = image_address + image_name
            raw_image = Image.open(file_name)
            image_tensor = transfer.preprocess(raw_image)
            image_features = transfer.image2tensor(image_tensor)
            return image_name, image_features

        input_list = [image_name for image_name in all_image]
        all_image_features = []
        with ThreadPoolExecutor(max_workers=16) as executor:
            futures = [executor.submit(function, inputs) for inputs in input_list]
            for future in tqdm(as_completed(futures), total=len(all_image)):
                image_name, image_features = future.result()
                all_image_features.append([image_name, image_features.cpu()])
        with open('../BigDataSource/Teddy2024/附件3/image_features_问题2.pkl', 'wb') as f:
            pickle.dump(all_image_features, f)
        print('图像数据预训练并存储完毕!')
    return all_image_features

def get_all_text():
    # 本地化保存
    try:
        with open('../BigDataSource/Teddy2024/附件3/text_features_问题2.pkl', 'rb') as f:
            all_text_features = pickle.load(f)
    except:
        # 获得附件1数据集代码
        all_text = pd.read_csv('../BigDataSource/Teddy2024/附件3/word_data.csv').to_numpy()[:, 1]

        # 获得预训练模型
        transfer = Raw2Vector('ViT-B-16', '1', args)

        # 准备用多线程代码迅速获得所有文本的张量
        from concurrent.futures import ThreadPoolExecutor, as_completed
        def function(inputs):
            raw_text = inputs
            text_tensor = clip.tokenize(raw_text).to(args.device)
            text_features = transfer.text2tensor(text_tensor)
            return raw_text, text_features

        input_list = [raw_text for raw_text in all_text]
        all_text_features = []
        with ThreadPoolExecutor(max_workers=16) as executor:
            futures = [executor.submit(function, inputs) for inputs in input_list]
            for future in tqdm(as_completed(futures), total=len(all_text)):
                raw_text, text_features = future.result()
                all_text_features.append([raw_text, text_features.cpu()])
        with open('../BigDataSource/Teddy2024/附件3/text_features_问题2.pkl', 'wb') as f:
            pickle.dump(all_text_features, f)
        print('文本数据预训练并存储完毕!')
    return all_text_features

In [24]:
args = get_config()
set_seed(2024)
transfer = Raw2Vector('ViT-B-16', '1', args)


{'experiment': 'Run the experiment now!'}
Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']
Loading vision model config from /Users/zengyuxiang/Documents/科研代码/CLIP_finetune/cn_clip/clip/model_configs/ViT-B-16.json
Loading text model config from /Users/zengyuxiang/Documents/科研代码/CLIP_finetune/cn_clip/clip/model_configs/RoBERTa-wwm-ext-base-chinese.json
Model info {'embed_dim': 512, 'image_resolution': 224, 'vision_layers': 12, 'vision_width': 768, 'vision_patch_size': 16, 'vocab_size': 21128, 'text_attention_probs_dropout_prob': 0.1, 'text_hidden_act': 'gelu', 'text_hidden_dropout_prob': 0.1, 'text_hidden_size': 768, 'text_initializer_range': 0.02, 'text_intermediate_size': 3072, 'text_max_position_embeddings': 512, 'text_num_attention_heads': 12, 'text_num_hidden_layers': 12, 'text_type_vocab_size': 2}


In [25]:
# 是否数据清洗
# preprocess() "ImageWordData_new"

# 首先获得所有的特征
all_text_features = get_all_text()
all_image_features = get_all_image()

In [26]:

# 修正序号
try:
    with open('../BigDataSource/Teddy2024/附件3/image_features_问题2_final.pkl', 'rb') as f:
        all_image_features = pickle.load(f)
    with open('../BigDataSource/Teddy2024/附件3/text_features_问题2_final.pkl', 'rb') as f:
        all_text_features = pickle.load(f)
except:
    
    # 文本部分
    text_data = pd.read_csv('../BigDataSource/Teddy2024/附件3/word_data.csv').to_numpy()
    all_text_idx = []
    for i in range(len(all_text_features)):
        all_text_idx.append(all_text_features[i][0])
    new_text_features = []
    for i in trange(len(text_data)):
        # print(text_data[i][1])
        text_idx = all_text_idx.index(text_data[i][1])
        new_text_features.append([text_data[i][0], text_data[i][1], all_text_features[text_idx][1]])
    all_text_features = new_text_features
    
    with open('../BigDataSource/Teddy2024/附件3/text_features_问题2_final.pkl', 'wb') as f:
        pickle.dump(all_text_features, f)
        
    # 视频部分
    image_data = pd.read_csv('../BigDataSource/Teddy2024/附件3/image_test.csv').to_numpy()
    all_image_idx = []
    new_image_features = []
    for i in range(len(all_image_features)):
        all_image_idx.append(all_image_features[i][0])
    for i in trange(len(image_data)):
        image_idx = all_image_idx.index(image_data[i][0])
        new_image_features.append([image_data[i][0], all_image_features[image_idx][1]])
    with open('../BigDataSource/Teddy2024/附件3/image_features_问题2_final.pkl', 'wb') as f:
        pickle.dump(all_image_features, f)
    all_image_features = new_image_features


100%|██████████| 50000/50000 [00:09<00:00, 5026.82it/s] 
100%|██████████| 5000/5000 [00:00<00:00, 29070.75it/s]


KeyboardInterrupt: 

In [28]:

# 直接获取张量
image_features = []
for i in range(len(all_image_features)):
    image_features.append(all_image_features[i][1])
image_features = torch.stack(image_features).squeeze(1)
print(image_features.shape)
print(all_image_features[0][0])
text_features = []
for i in range(len(all_text_features)):
    text_features.append(all_text_features[i][2])
text_features = torch.stack(text_features).squeeze(1)
print(text_features.shape)
print(all_text_features[0][1])


torch.Size([5000, 512])
Image14001013-8213.jpg
torch.Size([50000, 512])
洛阳楼盘 老城区楼盘 道北楼盘 保利<人名>


In [29]:
# 图像检索文本
print('-' * 80)
print('图像检索文本')
retrieve_method = 'L2'  # L2 KNN LSH
all_pred_rank = high_speed_retreive(image_features, text_features, retrieve_method, 100)

--------------------------------------------------------------------------------
图像检索文本
执行L2检索部分
L2:  2.12s


In [30]:
all_pred_rank[:, :5]

array([[2625, 3050, 1584, 3273, 2239],
       [ 626, 4374, 1031, 1530, 4409],
       [ 721, 4601,  397, 1725,  347],
       ...,
       [3437, 2979,  768, 1283, 4264],
       [4245, 3247, 1568, 3129, 2707],
       [2642, 1148, 1681,  379, 3527]])

In [31]:
for i in range(len(all_pred_rank)):
    print('-' * 80)
    print(all_text_features[i][0])
    for j in range(len(all_pred_rank[i])):
        print(all_image_features[all_pred_rank[i][j]][0])
        if j > 4:
            break
    if i >= 15:
        break

--------------------------------------------------------------------------------
Word-1000050001
Image14105001-4209.jpg
Image14105002-0128.jpg
Image14001015-7681.jpg
Image14105002-3035.jpg
Image14001016-6410.jpg
Image14105003-5600.jpg
--------------------------------------------------------------------------------
Word-1000050002
Image14001014-6338.jpg
Image14105003-8121.jpg
Image14001015-1205.jpg
Image14001015-7023.jpg
Image14105003-8764.jpg
Image14001015-4578.jpg
--------------------------------------------------------------------------------
Word-1000050003
Image14001014-7361.jpg
Image14105004-1311.jpg
Image14001014-3471.jpg
Image14001015-9660.jpg
Image14001014-2784.jpg
Image14001015-8968.jpg
--------------------------------------------------------------------------------
Word-1000050004
Image14105003-6963.jpg
Image14001015-5164.jpg
Image14001014-2703.jpg
Image14001014-6942.jpg
Image14001016-7163.jpg
Image14105003-4856.jpg
------------------------------------------------------------

In [33]:
ans = []
for i in range(len(all_pred_rank)):
    for j in range(len(all_pred_rank[i])):
        now = []
        now.append(all_text_features[i][0])
        now.append(j + 1)
        now.append(all_image_features[all_pred_rank[i][j]][0])
        ans.append(now)
        if j > 3:
            break
    # break
ans

[['Word-1000050001', 1, 'Image14105001-4209.jpg'],
 ['Word-1000050001', 2, 'Image14105002-0128.jpg'],
 ['Word-1000050001', 3, 'Image14001015-7681.jpg'],
 ['Word-1000050001', 4, 'Image14105002-3035.jpg'],
 ['Word-1000050001', 5, 'Image14001016-6410.jpg'],
 ['Word-1000050002', 1, 'Image14001014-6338.jpg'],
 ['Word-1000050002', 2, 'Image14105003-8121.jpg'],
 ['Word-1000050002', 3, 'Image14001015-1205.jpg'],
 ['Word-1000050002', 4, 'Image14001015-7023.jpg'],
 ['Word-1000050002', 5, 'Image14105003-8764.jpg'],
 ['Word-1000050003', 1, 'Image14001014-7361.jpg'],
 ['Word-1000050003', 2, 'Image14105004-1311.jpg'],
 ['Word-1000050003', 3, 'Image14001014-3471.jpg'],
 ['Word-1000050003', 4, 'Image14001015-9660.jpg'],
 ['Word-1000050003', 5, 'Image14001014-2784.jpg'],
 ['Word-1000050004', 1, 'Image14105003-6963.jpg'],
 ['Word-1000050004', 2, 'Image14001015-5164.jpg'],
 ['Word-1000050004', 3, 'Image14001014-2703.jpg'],
 ['Word-1000050004', 4, 'Image14001014-6942.jpg'],
 ['Word-1000050004', 5, 'Image1

In [None]:
# ans