In [1]:
import pickle

from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn.functional as F
import torch
from transformers import CLIPProcessor, CLIPModel
import faiss
import os
import numpy as np
from tqdm import tqdm
import json

### CLIP for image embedding

In [2]:
class ImageDataset(Dataset):
    def __init__(self, image_paths, processor):
        self.image_paths = image_paths
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        return self.processor(images=image, return_tensors="pt")['pixel_values'].squeeze(0)  # 删除多余的维度


In [3]:
from tqdm import tqdm
def encode_images(dataset, model):
    model.eval()
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=300,
                                              pin_memory=True,
                                              num_workers=4,
                                              prefetch_factor=2,
                                              shuffle=False)
    all_features = []
    with torch.no_grad():  # 关闭梯度计算以提高性能并减少内存使用
        for images in tqdm(data_loader):
            images = images.to("cuda")  # 假设使用 CUDA
            features = model.get_image_features(images)#.pooler_output
            all_features.append(features.cpu())  # 将特征移动到 CPU 并保存

    return torch.cat(all_features)  # 合并所有特征为一个 Tensor


In [4]:
# 初始化 CLIP 处理器和模型
clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
clip_model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14-336').to("cuda")


  return self.fget.__get__(instance, owner)()


# VisDiag

In [5]:
# 加载图片路径

search_space_path = json.load(open('./playground/data/css_data/Search_Space_val_50k.json'))
merge_image_path = ['./playground/data/css_data/' + p for p in search_space_path]

merge_image_path = sorted(merge_image_path)
id2image = dict(zip(range(len(merge_image_path)), merge_image_path))
with open('./checkpoints/id2image_clip_visdial.pickle', 'wb') as f:
    pickle.dump(id2image, f)

dataset = ImageDataset(merge_image_path, clip_processor)
# 编码图片
encoded_images = encode_images(dataset, clip_model)

100%|██████████| 250/250 [22:42<00:00,  5.45s/it]


In [6]:
encoded_images = encoded_images.numpy()
encoded_images /= np.linalg.norm(encoded_images, axis=1, keepdims=True) 

# 创建 Faiss 索引
dimension = encoded_images.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(encoded_images)
faiss.write_index(faiss_index,'./checkpoints/clip_faiss_visdial.index')

# save the embedding 
with open('./checkpoints/clip_image_embedding_visdial.pickle', 'wb') as f:
    pickle.dump(encoded_images, f)


# Flickr30K数据集

In [5]:
# 相同的代码在不同的数据集上运行
# 加载图片路径
image_fold = './playground/data/flickr30k/flickr30k_images'
image_paths = os.listdir(image_fold)#[:50]
merge_image_path = [os.path.join(image_fold, p) for p in image_paths]
merge_image_path = sorted(merge_image_path)
id2image = dict(zip(range(len(merge_image_path)), merge_image_path))
with open('./checkpoints/id2image_clip_flickr30k.pickle', 'wb') as f:
    pickle.dump(id2image, f)

dataset = ImageDataset(merge_image_path, clip_processor)
# 编码图片
encoded_images = encode_images(dataset, clip_model)

100%|██████████| 106/106 [14:25<00:00,  8.16s/it]


In [6]:
encoded_images = encoded_images.numpy()
encoded_images /= np.linalg.norm(encoded_images, axis=1, keepdims=True) 

# 创建 Faiss 索引
dimension = encoded_images.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(encoded_images)
faiss.write_index(faiss_index,'./checkpoints/clip_faiss_flickr30k.index')

# save the embedding 
with open('./checkpoints/clip_image_embedding_flickr30k.pickle', 'wb') as f:
    pickle.dump(encoded_images, f)

# MSCOCO数据集

In [5]:
# 相同的代码在不同的数据集上运行
# 加载图片路径
image_fold = './playground/data/mscoco/val2017'
image_paths = os.listdir(image_fold)#[:50]
merge_image_path = [os.path.join(image_fold, p) for p in image_paths]
merge_image_path = sorted(merge_image_path)
id2image = dict(zip(range(len(merge_image_path)), merge_image_path))
with open('./checkpoints/id2image_clip_mscoco.pickle', 'wb') as f:
    pickle.dump(id2image, f)

dataset = ImageDataset(merge_image_path, clip_processor)
# 编码图片
encoded_images = encode_images(dataset, clip_model)

100%|██████████| 25/25 [02:18<00:00,  5.56s/it]


In [6]:
encoded_images = encoded_images.numpy()
encoded_images /= np.linalg.norm(encoded_images, axis=1, keepdims=True) 

# 创建 Faiss 索引
dimension = encoded_images.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(encoded_images)
faiss.write_index(faiss_index,'./checkpoints/clip_faiss_mscoco.index')

# save the embedding 
with open('./checkpoints/clip_image_embedding_mscoco.pickle', 'wb') as f:
    pickle.dump(encoded_images, f)

# BLIP for image embedding 

In [2]:
### Encode image 

import io, requests
import torch 

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image

import json
class ImageDataset(torch.utils.data.Dataset):
    """ Dataset class for the corpus images (the 50k potential candidates)"""
    def __init__(self, image_paths, processor):
        self.image_paths = image_paths
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image_path = load_image(image_path)
        image = self.processor(images=image_path, return_tensors='pt')  # Load and prepare image
        return {'id': idx, 'pixel_values': image['pixel_values']}

def custom_collate_fn(batch):
    """
    Custom collate function to handle the processing of images.
    This function will combine the individual samples into a batch.
    """
    ids = [item['id'] for item in batch]
    pixel_values = torch.cat([item['pixel_values'] for item in batch])
    
    return {'id': ids, 'pixel_values': pixel_values}

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def encode_images(dataset, model):
    model.eval()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=250,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=False,
                                             prefetch_factor=2,
                                             collate_fn=custom_collate_fn
                                             )
    print("Preparing corpus (search space)...")
    corpus_vectors = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            image_embeds = model.vision_model(batch['pixel_values'].to(device))[0]
            batch_vectors = F.normalize(model.vision_proj(image_embeds[:, 0, :]), dim=-1)
            corpus_vectors.append(batch_vectors.cpu())
        corpus_vectors = torch.cat(corpus_vectors)
    return corpus_vectors

In [3]:
# 加载模型
from transformers import AutoProcessor, BlipForImageTextRetrieval
# model_id = "Salesforce/blip-image-captioning-base"
model_id = "Salesforce/blip-itm-base-coco"
blip_model = BlipForImageTextRetrieval.from_pretrained(model_id)
blip_model.to(device)
blip_processor = AutoProcessor.from_pretrained(model_id)

  return self.fget.__get__(instance, owner)()


## Flickr30K数据集

In [1]:
from transformers import AutoProcessor, BlipForImageTextRetrieval

blip_model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-flickr")
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-flickr")

config.json:   0%|          | 0.00/4.59k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/895M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/456 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [4]:
# 相同的代码在不同的数据集上运行
# 加载图片路径
image_fold = './playground/data/flickr30k/flickr30k_images'
image_paths = os.listdir(image_fold)#[:50]
merge_image_path = [os.path.join(image_fold, p) for p in image_paths]
merge_image_path = sorted(merge_image_path)
id2image = dict(zip(range(len(merge_image_path)), merge_image_path))
with open('./checkpoints/id2image_blip_flickr30k.pickle', 'wb') as f:
    pickle.dump(id2image, f)

dataset = ImageDataset(merge_image_path, blip_processor)
# 编码图片
encoded_images = encode_images(dataset, blip_model)

Preparing corpus (search space)...


100%|██████████| 128/128 [04:44<00:00,  2.22s/it]


In [7]:
len(dataset.__getitem__(0)['pixel_values'])

1

In [8]:
encoded_images = encoded_images.numpy()
encoded_images /= np.linalg.norm(encoded_images, axis=1, keepdims=True) 

# 创建 Faiss 索引
dimension = encoded_images.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(encoded_images)
faiss.write_index(faiss_index,'./checkpoints/blip_faiss_flickr30k.index')

# save the embedding 
with open('./checkpoints/blip_image_embedding_flickr30k.pickle', 'wb') as f:
    pickle.dump(encoded_images, f)

AttributeError: 'numpy.ndarray' object has no attribute 'numpy'

# MSCOCO数据集

In [9]:
# 相同的代码在不同的数据集上运行
# 加载图片路径
image_fold = './playground/data/mscoco/val2017'
image_paths = os.listdir(image_fold)#[:50]
merge_image_path = [os.path.join(image_fold, p) for p in image_paths]
merge_image_path = sorted(merge_image_path)
id2image = dict(zip(range(len(merge_image_path)), merge_image_path))
with open('./checkpoints/id2image_blip_mscoco.pickle', 'wb') as f:
    pickle.dump(id2image, f)

dataset = ImageDataset(merge_image_path, blip_processor)
# 编码图片
encoded_images = encode_images(dataset, blip_model)

Preparing corpus (search space)...


100%|██████████| 20/20 [00:48<00:00,  2.41s/it]


In [10]:
encoded_images = encoded_images.numpy()
encoded_images /= np.linalg.norm(encoded_images, axis=1, keepdims=True) 

# 创建 Faiss 索引
dimension = encoded_images.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(encoded_images)
faiss.write_index(faiss_index,'./checkpoints/blip_faiss_mscoco.index')

# save the embedding 
with open('./checkpoints/blip_image_embedding_mscoco.pickle', 'wb') as f:
    pickle.dump(encoded_images, f)

### Visdial Dataset (50k)

In [6]:
search_space_path = json.load(open('./playground/data/css_data/Search_Space_val_50k.json'))
merge_image_path = ['./playground/data/css_data/' + p for p in search_space_path]

merge_image_path = sorted(merge_image_path)
id2image = dict(zip(range(len(merge_image_path)), merge_image_path))
with open('./checkpoints/id2image_blip_visdial.pickle', 'wb') as f:
    pickle.dump(id2image, f)

dataset = ImageDataset(merge_image_path, blip_processor)
# 编码图片
encoded_images = encode_images(dataset, blip_model)

Preparing corpus (search space)...


100%|██████████| 200/200 [07:28<00:00,  2.24s/it]


In [7]:
encoded_images = encoded_images.numpy()
encoded_images /= np.linalg.norm(encoded_images, axis=1, keepdims=True) 

# 创建 Faiss 索引
dimension = encoded_images.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(encoded_images)
faiss.write_index(faiss_index,'./checkpoints/blip_faiss_visdial.index')

# save the embedding 
with open('./checkpoints/blip_image_embedding_visdial.pickle', 'wb') as f:
    pickle.dump(encoded_images, f)

### query expansion

In [4]:
# 相同的代码在不同的数据集上运行
# 加载图片路径
image_fold = './playground/data/flickr30k/flickr30k_images'
image_paths = os.listdir(image_fold)#[:50]
merge_image_path = [os.path.join(image_fold, p) for p in image_paths]
merge_image_path = sorted(merge_image_path)
id2image = dict(zip(range(len(merge_image_path)), merge_image_path))
with open('./checkpoints/id2image_blip_flickr30k.pickle', 'wb') as f:
    pickle.dump(id2image, f)

dataset = ImageDataset(merge_image_path, blip_processor)
# 编码图片
encoded_images = encode_images(dataset, blip_model)

Preparing corpus (search space)...


100%|██████████| 128/128 [04:32<00:00,  2.13s/it]


In [5]:
encoded_images = encoded_images.numpy()
encoded_images /= np.linalg.norm(encoded_images, axis=1, keepdims=True) 

# 创建 Faiss 索引
dimension = encoded_images.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(encoded_images)
faiss.write_index(faiss_index,'./checkpoints/blip_faiss_flickr30k.index')

# save the embedding 
with open('./checkpoints/blip_image_embedding_flickr30k.pickle', 'wb') as f:
    pickle.dump(encoded_images, f)

ChatIR

In [None]:
import torch 
#图片搜索
class ImageEmbedder:
    def __init__(self, model, preprocessor):
        """ model projects image to vector, processor load and prepare image to the model"""
        self.model = model
        self.processor = preprocessor

def BLIP_BASELINE():
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode
    import torch.nn.functional as F

    import sys
    sys.path.insert(0, './BLIP')
    from BLIP.models.blip_itm import blip_itm
    # load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_url = './BLIP/model_base_retrieval_coco.pth'
    model = blip_itm(pretrained=model_url, image_size=384, vit='base')
    # print(model)
    
    model = model.to(device).eval()

    # define Image Embedder (raw_image --> img_feature)
    transform_test = transforms.Compose([
        transforms.Resize((384, 384), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])

    def blip_project_img(image):
        embeds = model.visual_encoder(image)
        projection = model.vision_proj(embeds[:, 0, :])
        return F.normalize(projection, dim=-1)

    def blip_prep_image(path):
        raw = Image.open(path).convert('RGB')
        return transform_test(raw)

    image_embedder = ImageEmbedder(blip_project_img, lambda path: blip_prep_image(path))

    # define dialog encoder (dialog --> img_feature)
    def dialog_encoder(dialog):
        text = model.tokenizer(dialog, padding='longest', truncation=True,
                               max_length=200,
                               return_tensors="pt"
                               ).to(device)

        text_output = model.text_encoder(text.input_ids, attention_mask=text.attention_mask,
                                         return_dict=True, mode='text')

        shift = model.text_proj(text_output.last_hidden_state[:, 0, :])
        return F.normalize(shift, dim=-1)

    return dialog_encoder, image_embedder
dialog_encoder, image_embedder = BLIP_BASELINE()

BLIP_ITM(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
  

In [2]:
import json
class ImageDataset(torch.utils.data.Dataset):
    """ Dataset class for the corpus images (the 50k potential candidates)"""
    def __init__(self, image_paths, preprocessor):
        self.image_paths = image_paths
        self.preprocessor = preprocessor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = self.preprocessor(image_path)  # Load and prepare image
        return {'id': idx, 'image': image}

In [3]:
from tqdm import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def encode_images(dataset, image_embedder):
    # model.eval()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=50,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=False,
                                             prefetch_factor=2
                                             )
    print("Preparing corpus (search space)...")
    corpus_vectors = []
    # corpus_ids = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch_vectors = F.normalize(image_embedder.model(batch['image'].to(device)), dim=-1)
            corpus_vectors.append(batch_vectors)
            # corpus_ids.append(batch['id'].to(device))

        corpus_vectors = torch.cat(corpus_vectors)
    return corpus_vectors

In [14]:
import os 
import pickle
import torch.nn.functional as F 
from PIL import Image
# 加载图片路径
image_fold = './playground/data/mscoco/val2017'
image_paths = os.listdir(image_fold)#[:20]
merge_image_path = [os.path.join(image_fold, p) for p in image_paths]
merge_image_path = sorted(merge_image_path)
id2image = dict(zip(range(len(merge_image_path)), merge_image_path))
with open('./checkpoints/id2image_blip_mscoco.pickle', 'wb') as f:
    pickle.dump(id2image, f)

dataset = ImageDataset(merge_image_path, image_embedder.processor)
# 编码图片
encoded_images = encode_images(dataset, image_embedder)

Preparing corpus (search space)...


100%|██████████| 100/100 [00:44<00:00,  2.25it/s]


In [15]:
import numpy as np
import faiss 
import pickle
encoded_images = encoded_images.cpu().numpy()
encoded_images /= np.linalg.norm(encoded_images, axis=1, keepdims=True) 

# 创建 Faiss 索引
dimension = encoded_images.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(encoded_images)
faiss.write_index(faiss_index,'./checkpoints/blip_faiss_mscoco.index')

# save the embedding 
with open('./checkpoints/blip_image_embedding_mscoco.pickle', 'wb') as f:
    pickle.dump(encoded_images, f)

### CharIR BLIP for image embedding


In [18]:

class ImageEmbedder:
    def __init__(self, model, preprocessor):
        """ model projects image to vector, processor load and prepare image to the model"""
        self.model = model
        self.processor = preprocessor

def BLIP_BASELINE():
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    import sys
    sys.path.insert(0, './BLIP')
    from BLIP.models.blip_itm import blip_itm
    # load model
    model = blip_itm(pretrained='./BLIP/chatir_weights.ckpt',  # Download from Google Drive, see README.md
                     med_config='BLIP/configs/med_config.json',
                     image_size=224,
                     vit='base')

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    # define Image Embedder (raw_image --> img_feature)
    transform_test = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])

    def blip_project_img(image):
        embeds = model.visual_encoder(image)
        projection = model.vision_proj(embeds[:, 0, :])
        return F.normalize(projection, dim=-1)

    def blip_prep_image(path):
        raw = Image.open(path).convert('RGB')
        return transform_test(raw)

    image_embedder = ImageEmbedder(blip_project_img, lambda path: blip_prep_image(path))

    # define dialog encoder (dialog --> img_feature)
    def dialog_encoder(dialog):
        text = model.tokenizer(dialog, padding='longest', truncation=True,
                               max_length=200,
                               return_tensors="pt").to(device)

        text_output = model.text_encoder(text.input_ids, attention_mask=text.attention_mask,
                                         return_dict=True, mode='text')

        shift = model.text_proj(text_output.last_hidden_state[:, 0, :])
        return F.normalize(shift, dim=-1)

    return dialog_encoder, image_embedder

In [22]:
import json
class ImageDataset(torch.utils.data.Dataset):
    """ Dataset class for the corpus images (the 50k potential candidates)"""
    def __init__(self, image_paths, preprocessor):
        self.image_paths = image_paths
        self.preprocessor = preprocessor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = self.preprocessor(image_path)  # Load and prepare image
        return {'id': idx, 'image': image}

In [23]:
# dataset = ImageDataset(merge_image_path, image_embedder.processor)
#
# next(iter(dataset))

In [26]:
from tqdm import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def encode_images(dataset, image_embedder):
    # model.eval()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=500,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=False,
                                             prefetch_factor=2
                                             )
    print("Preparing corpus (search space)...")
    corpus_vectors = []
    # corpus_ids = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch_vectors = F.normalize(image_embedder.model(batch['image'].to(device)), dim=-1)
            corpus_vectors.append(batch_vectors)
            # corpus_ids.append(batch['id'].to(device))

        corpus_vectors = torch.cat(corpus_vectors)
    return corpus_vectors

In [27]:
# 初始化 CLIP 处理器和模型
dialog_encoder, image_embedder = BLIP_BASELINE()
# 加载图片路径
image_fold = './playground/data/css_data/unlabeled2017'
image_paths = os.listdir(image_fold)#[:20]
merge_image_path = [os.path.join(image_fold, p) for p in image_paths]
merge_image_path = sorted(merge_image_path)
id2image = dict(zip(range(len(merge_image_path)), merge_image_path))
with open('./checkpoints/id2image.pickle', 'wb') as f:
    pickle.dump(id2image, f)

dataset = ImageDataset(merge_image_path, image_embedder.processor)
# 编码图片
encoded_images = encode_images(dataset, image_embedder)

load checkpoint from ./BLIP/chatir_weights.ckpt
Preparing corpus (search space)...


100%|██████████| 247/247 [06:37<00:00,  1.61s/it]


In [None]:
encoded_images.shape

In [4]:

# 创建 Faiss 索引
dimension = encoded_images.shape[1]
faiss_index = faiss.IndexFlatIP(dimension) # 
faiss_index.add(encoded_images)
faiss.write_index(faiss_index,'./checkpoints/blip_faiss.index')

In [5]:
with open('checkpoints/blip_image_embedding.pickle', 'wb') as f:
    pickle.dump(encoded_images, f)

### CLIP for Text Embedding

In [5]:
import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, texts, processor):
        """
        Args:
            texts (list of str): List of text strings.
            processor (transformers processor): Processor to tokenize the text.
        """
        self.texts = texts
        self.processor = processor

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]  # Get the text at the provided index
        inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True, max_length=100)
        return inputs['input_ids'].squeeze(0), inputs['attention_mask'].squeeze(0)

def text_data_collector(batch):
    input_ids, attention_masks = zip(*batch)
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,batch_first=True,padding_value=0)
    attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True,padding_value=0)
    return input_ids, attention_masks


In [6]:
def encode_text(dataset, data_collector, model):
    model.eval()
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=50,
                                              pin_memory=True,
                                              num_workers=4,
                                              prefetch_factor=2,
                                              shuffle=False,
                                              collate_fn=data_collector)
    all_features = []
    with torch.no_grad():  # 关闭梯度计算以提高性能并减少内存使用
        for batch in data_loader:
            input_ids, attention_masks = batch
            input_ids = input_ids.to("cuda")
            attention_masks = attention_masks.to("cuda")
            features = model.get_text_features(input_ids=input_ids, attention_mask=attention_masks)#.pooler_output
            print(features)
            all_features.append(features.cpu())  # 将特征移动到 CPU 并保存
    return torch.cat(all_features)  # 合并所有特征为一个 Tensor



In [7]:
# 初始化 CLIP 处理器和模型
processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
clip_model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14-336').to("cuda")

texts = ["[PAD]", "This is a text dataset example.", "We are learning about AI!"]
# Create an instance of the dataset
text_dataset = TextDataset(texts, processor)

# 编码图片
encoded_ = encode_text(text_dataset, text_data_collector, clip_model)

tensor([[ 3.7448e-01,  6.2039e-01,  2.1363e-01,  ...,  2.8336e-02,
          2.2532e-01,  1.7659e-01],
        [ 5.6634e-01,  9.4504e-01,  2.3726e-01,  ..., -1.8392e-01,
          2.4562e-01,  1.6248e-01],
        [-2.7422e-01, -4.1648e-01,  3.1572e-01,  ..., -7.5832e-02,
         -3.1761e-01, -2.7838e-04]], device='cuda:0')


In [12]:
def retrieve_topk_images(query,
                         topk=10,
                         faiss_model=None,
                         clip_model=None,
                         id2image=None,
                         processor=None, ):
    text_dataset = TextDataset(query, processor)
    query_vec = encode_text(text_dataset, text_data_collector, clip_model)
    distance, indices = faiss_model.search(np.array([query_vec]), topk)
    return [id2image.get(i, None) for i in indices]

In [11]:
query  = ["[PAD]", "This is a text dataset example.", "We are learning about AI!"]
faiss_model = faiss.read_index('./checkpoints/clip_faiss.index')
retrieve_topk_images(query,
                     topk=5,
                     faiss_model=faiss_model,
                     clip_model=None,
                     id2image=None,
                     processor=None, )

NameError: name 'TextDataset' is not defined