In [1]:
"""
init
"""
import clip
import json
import torch
import faiss
import numpy as np
from PIL import Image
from tqdm import tqdm
# Setting MAX_IMAGE_PIXELS to None switches off Pillow's image-size limit
# (its decompression-bomb check), so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None

### init clip
# Load the CLIP "ViT-B/16" checkpoint on the GPU. `preprocess` is the transform
# applied to PIL images before they are passed to `clip_model.encode_image`
# in the later cells.
clip_model, preprocess = clip.load("ViT-B/16", device="cuda")


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
[2024-03-14 01:13:12,456] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [None]:
"""
select the dataset and shot number
dataset: caltech101, eurosat, dtd, food101, oxford_flowers, oxford_pets, stanford_cars, sun397, ucf101, ...
shot number : 1, 2, 4, 8, 16
"""
# Maximum number of training images kept per class; also embedded in the
# names of the generated database/index/prediction files.
shot_number = 4
# Dataset identifier; interpolated into the database and benchmark paths
# used by the later cells.
# NOTE(review): the saved output of the classnames cell lists 47 texture
# names, which does not look like 'caltech101' -- the cells were presumably
# last run with a different dataset_name. Re-run all cells after changing it.
dataset_name = 'caltech101'

In [2]:
def index_save(index_split, path):
    """Serialize a FAISS index to ``path``, creating the target folder if needed.

    Args:
        index_split: The FAISS index object to write.
        path: Destination file path of the serialized index.

    Returns:
        None.
    """
    import os  # local import so the helper stays self-contained in this cell

    # Make sure the destination folder exists first, so the write cannot fail
    # merely because the database directory has not been created yet.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    faiss.write_index(index_split, path)

def get_number(str_data):
    """Return the ``label`` entry of a dict-like annotation string.

    The string is first parsed the original way: every single quote is turned
    into a double quote and the result is decoded as JSON. That conversion
    corrupts records whose values contain an apostrophe and cannot handle
    Python literals such as ``None``, so when JSON decoding fails the
    untouched text is parsed as a Python literal instead.

    Args:
        str_data: Text form of a mapping that contains a ``'label'`` key,
            e.g. ``"{'label': 3, 'classname': 'dog'}"``.

    Returns:
        The value stored under ``'label'``.

    Raises:
        ValueError, SyntaxError: If the text is neither valid JSON (after the
            quote replacement) nor a valid Python literal.
        KeyError: If the parsed mapping has no ``'label'`` key.
    """
    import ast  # local import: only needed for the fallback parser

    try:
        data = json.loads(str_data.replace("'", '"'))
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        # Fall back to the original text; the quote replacement may be what broke it.
        data = ast.literal_eval(str_data.strip())
    return data['label']

def find_positions_and_paths(gt_labels, image_paths, n):
    """Collect the positions of the first ``n`` samples of every label.

    Args:
        gt_labels: Iterable of labels, one per sample, in dataset order.
        image_paths: Unused; kept so existing callers keep working.
        n: Maximum number of positions to keep per label. With ``n <= 0``
            every label seen is still present in the result, mapped to an
            empty list.

    Returns:
        dict mapping each label (in order of first appearance) to the list of
        indices into ``gt_labels`` of its first ``n`` occurrences.
    """
    positions = {}
    for idx, label in enumerate(gt_labels):
        selected = positions.setdefault(label, [])
        # Apply the cap to every append, including the first one, so that
        # n <= 0 no longer lets one sample per class slip through.
        if len(selected) < n:
            selected.append(idx)
    return positions

def write_to_file(str_list, num_list, file_path):
    """Write ``"<string> <number>"`` pairs to a UTF-8 text file, one per line.

    Args:
        str_list: Iterable of strings (the first column of each line).
        num_list: Iterable of values written after a single space.
        file_path: Destination file; it is overwritten if it exists and its
            parent folder is created when missing.

    Returns:
        None.

    Note:
        The two iterables are paired with ``zip``, so output stops at the
        shorter one.
    """
    import os  # local import so the helper stays self-contained in this cell

    # Create the target folder up front; opening the file for writing would
    # otherwise raise FileNotFoundError on a fresh checkout.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as file:
        file.writelines(f"{s} {n}\n" for s, n in zip(str_list, num_list))

In [24]:
"""
get the classnames
"""
# Text file holding the class names of the selected dataset.
classnames_file_path = f"/mnt/petrelfs/liuziyu/LLM_Memory/SimplyRetrieve/CLIP-Cls/benchmarks_test/{dataset_name}_database/classnames.txt"
with open(classnames_file_path, 'r') as file:
    raw_classnames = file.readlines()
# Report how many entries were read, then show them without surrounding whitespace.
print(len(raw_classnames))
classnames = list(map(str.strip, raw_classnames))
print(classnames)

47
['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', 'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged']


In [None]:

### save_file_path: .txt file that records the selected k-shot images and their order, so they can be retrieved later
save_file_path = f"./database/{dataset_name}_database/{dataset_name}_{shot_number}_shot_database.txt"
### file_path: listing of the whole trainset (image path and label on each line); this file can be found in the CLIP_Cls folder
file_path = f"/mnt/petrelfs/liuziyu/LLM_Memory/SimplyRetrieve/CLIP-Cls/benchmarks_test/{dataset_name}_database/trainset.txt"

with open(file_path, 'r') as file:
    lines = file.readlines()

# First whitespace-delimited token of each line is the image path.
image_paths = [fields[0] for fields in map(str.split, lines)]

# Everything after the first space is the label record, parsed by get_number.
labels = [line.split(' ', 1)[1] for line in lines]
gt_labels = list(map(get_number, labels))
# Alternative for listings whose second column is a bare integer:
#   gt_labels = [int(line.split()[1]) for line in lines]

print(len(image_paths))
print(len(gt_labels))
data_postition = find_positions_and_paths(gt_labels, image_paths, shot_number)
print(data_postition)

### normal
# Flatten the per-class positions into two parallel lists, ordered by class index.
select_gt_labels = []
select_image_path = []
for class_id in range(len(classnames)):
    for sample_idx in data_postition[class_id]:
        select_gt_labels.append(class_id)
        select_image_path.append(image_paths[sample_idx])
print(select_gt_labels)
print(select_image_path)
write_to_file(select_image_path, select_gt_labels, save_file_path)

In [26]:
"""
build memory
"""
# k-shot database listing written by the selection cell: "<image path> <label>" per line.
file_path = f"./database/{dataset_name}_database/{dataset_name}_{shot_number}_shot_database.txt"
with open(file_path, 'r') as file:
    lines = file.readlines()
image_paths = [line.split()[0] for line in lines]
print(len(image_paths))
if not image_paths:
    # Fail early with a clear message instead of a shape error inside FAISS.
    raise ValueError(f"No image paths found in {file_path}")

### index save path
index_img_save_path =  f"./database/{dataset_name}_database/{dataset_name}_{shot_number}_shot_img_index.index"

### build index
# 512 = dimension of the vectors added below, 64 = HNSW neighbours per node;
# inner product on the L2-normalised features is cosine similarity.
index_img = faiss.IndexHNSWFlat(512, 64, faiss.METRIC_INNER_PRODUCT)

### embedding images
embed_rows = []
with torch.no_grad():
    for image_path in tqdm(image_paths,desc="Process:"):
        # Close the file handle as soon as the image has been preprocessed.
        with Image.open(image_path) as pil_image:
            image = preprocess(pil_image).unsqueeze(0).to("cuda")
        image_features = clip_model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # Hand FAISS float32 data explicitly rather than relying on an implicit cast.
        embed_rows.append(image_features.cpu().float().numpy())
# Each entry has one row per image (batch of 1), so stacking along axis 0 keeps
# the array 2-D even for a single image; squeeze() would collapse that case to
# 1-D and make index.add fail.
embed_img = np.concatenate(embed_rows, axis=0)
index_img.add(embed_img)
print("Total number of indexes:", index_img.ntotal)

### save index
print("Saving index:")
index_save(index_img, index_img_save_path)
print("Done")

752


Process:: 100%|██████████| 752/752 [00:09<00:00, 76.32it/s]

Total number of indexes: 752
Saving index:
Done





In [12]:
### get CLIP+KNN prediction results
index_img_save_path = f"./database/{dataset_name}_database/{dataset_name}_{shot_number}_shot_img_index.index"
trainset_file_path = f"./database/{dataset_name}_database/{dataset_name}_{shot_number}_shot_database.txt"
predictions_save_path = f"/mnt/petrelfs/liuziyu/LLM_Memory/SimplyRetrieve/CLIP-Cls/output/ZeroshotCLIP_topk/vit_b16/{dataset_name}/predictions_{shot_number}_shot_knn.pth"
index = faiss.read_index(index_img_save_path)
pth_file_path = f'/mnt/petrelfs/liuziyu/LLM_Memory/SimplyRetrieve/CLIP-Cls/output/ZeroshotCLIP_topk/vit_b16/{dataset_name}/predictions.pth'
# NOTE(review): torch.load unpickles the file -- only load prediction files from a trusted source.
predictions = torch.load(pth_file_path)

# Labels of the database entries in file order. The build-memory cell adds the
# image vectors in the same order, so position i here is the label of index id i.
with open(trainset_file_path, 'r') as file:
    lines = file.readlines()
database_labels = [int(line.strip().split(' ', 1)[1]) for line in lines]

if index.ntotal == 0:
    raise ValueError(f"Index {index_img_save_path} is empty")
# Never request more neighbours than the index holds: FAISS pads missing
# results with id -1, which would silently resolve to the last database line.
top_k = min(100, index.ntotal)

for prediction in tqdm(predictions,desc="Process:"):
    # Each prediction maps a test image path to a record holding 'pred_class'.
    for test_img_path, item in prediction.items():
        with torch.no_grad():
            with Image.open(test_img_path) as pil_image:
                image = preprocess(pil_image).unsqueeze(0).to("cuda")
            # torch.Size([1, 512])
            image_features = clip_model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            image_features = image_features.cpu().float().numpy()
        distance, index_result = index.search(image_features, top_k)

        # Map retrieved ids to their labels; skip any -1 placeholder the
        # search returns when it finds fewer than top_k neighbours.
        labels = torch.tensor(
            [database_labels[idx] for idx in index_result[0] if idx >= 0],
            dtype=torch.long,
        )

        # Replace the stored prediction with the retrieved neighbour labels.
        # NOTE: this holds fewer than 100 labels when the index is smaller than that.
        item['pred_class'] = labels

torch.save(predictions, predictions_save_path)

Process:: 100%|██████████| 3783/3783 [00:42<00:00, 89.89it/s] 
