In [3]:
"""
Initial setup: imports, Pillow configuration and the shared CLIP model.
"""
import os
import re
import json
import clip
import faiss
import torch
import random
import numpy as np
from tqdm import tqdm
from PIL import Image
# NOTE(review): assigning None presumably removes Pillow's pixel-count limit
# (the decompression-bomb check) so that very large images can be opened --
# confirm this is acceptable for the images being loaded.
Image.MAX_IMAGE_PIXELS = None

### Initialise a CLIP model ("ViT-B/16") on the "cuda" device. clip.load
### returns two values; the second one, `preprocess`, is applied to every PIL
### image before it is passed to clip_model.encode_image in later cells.
clip_model, preprocess = clip.load("ViT-B/16", device="cuda")

In [4]:
# Load the training split. Every non-blank line has the form
# "<image path> <label text>"; the first run of digits found in the label text
# is used as the integer class id. `trainset` ends up as a list of
# [image_path, label] pairs.
trainset_file_path = "./fgvc_database/trainset.txt"
trainset = []
with open(trainset_file_path, 'r') as file:
    for line_number, line in enumerate(file, start=1):
        entry = line.strip()
        if not entry:
            # Tolerate blank lines (e.g. a stray newline at the end of the
            # file) instead of failing on the tuple unpack.
            continue
        image_path, separator, label_text = entry.partition(' ')
        # Use a regular expression to find the numeric label.
        label_match = re.search(r'\d+', label_text)
        if not separator or label_match is None:
            # Fail with the offending line rather than an opaque
            # "NoneType has no attribute group" / unpack error.
            raise ValueError(
                f"{trainset_file_path}:{line_number}: expected "
                f"'<image path> <label>', got {entry!r}"
            )
        trainset.append([image_path, int(label_match.group())])
print(len(trainset))

# Class names, one per line. The raw lines (including the trailing newline)
# are kept; consumers call .strip() when they look a name up by label.
classnames_file_path = "./fgvc_database/classnames.txt"
with open(classnames_file_path, 'r') as file:
    classnames = file.readlines()
print(len(classnames))

3334
100


In [79]:
# Build 10,000 image-pair records by random sampling: first 5,000 "No" pairs
# (two images with different labels), then 5,000 "Yes" pairs (same label).
# The result is written to ./my_list.json.
RANDOM_PAIRS_PER_ANSWER = 5000

# Runtime prompt/answer templates. 'xxx', 'xxx1' and 'xxx2' are placeholders
# that get replaced with class names below.
random_pair_question = '以下两张图片中的飞机，是否属于同一型号<Img index=0><image></Img> <Img index=1><image></Img>请直接给出Yes或No的回答，再分析原因'
random_no_answer = 'No. <InsertImg index=0>是xxx1型号，而 <InsertImg index=1>是xxx2型号'
random_yes_answer = 'Yes. <InsertImg index=0>和<InsertImg index=1>都是xxx型号'

# Rejection sampling below never terminates if the requested kind of pair
# cannot exist, so check that up front.
sampled_labels = [label for _, label in trainset]
if len(set(sampled_labels)) < 2:
    raise ValueError("trainset needs at least two different labels to build 'No' pairs")
if len(set(sampled_labels)) == len(sampled_labels):
    raise ValueError("trainset needs two images sharing a label to build 'Yes' pairs")

json_data = []
for want_same_model in (False, True):
    count = 0
    while count < RANDOM_PAIRS_PER_ANSWER:
        first, second = random.sample(trainset, 2)
        if (first[1] == second[1]) != want_same_model:
            continue  # wrong kind of pair for this pass; draw again

        first_name = classnames[int(first[1])].strip()
        second_name = classnames[int(second[1])].strip()
        if want_same_model:
            answer = random_yes_answer.replace('xxx', first_name)
        else:
            answer = random_no_answer.replace('xxx1', first_name).replace('xxx2', second_name)

        json_data.append({
            # Running index of the record. The previous `i` was never defined
            # in this cell, so it either raised NameError or reused a stale
            # value left behind by another cell.
            'id': len(json_data),
            'image': [first[0], second[0]],
            'conversations': [
                {'from': 'user', 'value': random_pair_question},
                {'from': 'assistant', 'value': answer},
            ],
        })
        count += 1

with open('./my_list.json', 'w', encoding='utf-8') as file:
    json.dump(json_data, file, ensure_ascii=False)

In [5]:
# Build image-pair records with CLIP + FAISS: each training image is paired
# with its nearest neighbours from the image index, producing up to 40 "No"
# pairs (neighbour has a different label) and up to 40 "Yes" pairs (same
# label) per image. The result is written to ./clip_genetator.json.

### Initialisation
index_img_save_path = "./fgvc_database/index_img.index"
index = faiss.read_index(index_img_save_path)

PAIR_TOP_K = 100            # neighbours requested per query image
MAX_PAIRS_PER_ANSWER = 40   # cap on "No" pairs and on "Yes" pairs per query image
# Image paths are rewritten from the first prefix to the second in the output.
OLD_IMAGE_ROOT = "/mnt/petrelfs/liuziyu/LLM_Memory/SimplyRetrieve/CLIP-Cls/data/"
NEW_IMAGE_ROOT = "/mnt/petrelfs/share_data/zangyuhang/img-cls/"

# Runtime prompt/answer templates. 'xxx', 'xxx1' and 'xxx2' are placeholders
# that get replaced with class names below.
clip_pair_question = '以下两张图片中的飞机，是否属于同一型号<Img index=0><image></Img> <Img index=1><image></Img>请直接给出Yes或No的回答，再分析原因'
clip_no_answer = 'No. <InsertImg index=0>是xxx1型号，而 <InsertImg index=1>是xxx2型号'
clip_yes_answer = 'Yes. <InsertImg index=0>和<InsertImg index=1>都是xxx型号'

# Parse the trainset file once instead of re-reading it for every query image.
# Row r of the search result is looked up as line r of this file. Lines that
# cannot be parsed are stored as None and only reported if actually retrieved.
with open(trainset_file_path, 'r') as file:
    lines = file.readlines()
trainset_rows = []
for line in lines:
    parts = line.strip().split(' ', 1)
    label_match = re.search(r":\s*(\d+)", parts[1]) if len(parts) == 2 else None
    trainset_rows.append((parts[0], int(label_match.group(1))) if label_match else None)

json_data = []
yes_number = 0
no_number = 0
### Iterate over the dataset
for train_data in tqdm(trainset):

    ### Image path, label and class name of the query image
    train_data_label = int(train_data[1])
    train_data_image = train_data[0]
    train_data_name = classnames[train_data_label].strip()

    ### Retrieve the nearest images
    with torch.no_grad():
        # Close the file handle as soon as the image has been preprocessed.
        with Image.open(train_data_image) as raw_image:
            image = preprocess(raw_image).unsqueeze(0).to("cuda")
        image_features = clip_model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    # FAISS works on float32 vectors; make the conversion explicit.
    query_vector = image_features.cpu().numpy().astype(np.float32)
    _distances, index_result = index.search(query_vector, PAIR_TOP_K)

    neighbours = []
    for row in index_result[0]:
        if row < 0:
            # FAISS pads with -1 when fewer than k neighbours exist; indexing
            # with -1 would silently pick the last line of the file.
            continue
        parsed_row = trainset_rows[row]
        if parsed_row is None:
            raise ValueError(
                f"{trainset_file_path}: line {int(row) + 1} is not of the form "
                f"'<image path> <text>: <label>': {lines[row]!r}"
            )
        neighbours.append(parsed_row)

    ### Build the records. The first hit is skipped, as before
    ### (NOTE(review): presumably it is the query image itself -- confirm).
    candidates = neighbours[1:]
    query_path = train_data_image.replace(OLD_IMAGE_ROOT, NEW_IMAGE_ROOT)

    # "No" pairs first: nearest neighbours with a different label.
    different_model = [hit for hit in candidates if hit[1] != train_data_label]
    for neighbour_path, neighbour_label in different_model[:MAX_PAIRS_PER_ANSWER]:
        neighbour_name = classnames[neighbour_label].strip()
        json_data.append({
            'id': int(len(json_data)),
            'image': [query_path, neighbour_path.replace(OLD_IMAGE_ROOT, NEW_IMAGE_ROOT)],
            'conversations': [
                {'from': 'user', 'value': clip_pair_question},
                {'from': 'assistant', 'value': clip_no_answer.replace('xxx1', train_data_name).replace('xxx2', neighbour_name)},
            ],
        })
        no_number += 1

    # Then "Yes" pairs: nearest neighbours with the same label.
    same_model = [hit for hit in candidates if hit[1] == train_data_label]
    for neighbour_path, _neighbour_label in same_model[:MAX_PAIRS_PER_ANSWER]:
        json_data.append({
            'id': int(len(json_data)),
            'image': [query_path, neighbour_path.replace(OLD_IMAGE_ROOT, NEW_IMAGE_ROOT)],
            'conversations': [
                {'from': 'user', 'value': clip_pair_question},
                {'from': 'assistant', 'value': clip_yes_answer.replace('xxx', train_data_name)},
            ],
        })
        yes_number += 1

print(yes_number)
print(no_number)

with open('./clip_genetator.json', 'w', encoding='utf-8') as file:
    json.dump(json_data, file, ensure_ascii=False)

100%|██████████| 3334/3334 [01:25<00:00, 39.12it/s]


39282
133360


In [6]:
# Keep only one record per image pair so the dataset contains no duplicates.
# A pair is identified by its two image paths regardless of their order; the
# first record seen for a pair is the one that is kept.
with open('./clip_genetator.json', 'r', encoding='utf-8') as source:
    data = json.load(source)

unique_image_pairs = set()
new_data = []

# Drop every record whose (order-insensitive) image pair was already seen.
for record in data:
    pair_key = tuple(sorted(record['image']))
    if pair_key in unique_image_pairs:
        continue
    unique_image_pairs.add(pair_key)
    new_data.append(record)

# Renumber the ids so they are consecutive again.
for new_id, record in enumerate(new_data):
    record["id"] = new_id

print(len(new_data))

# Write the cleaned records to a new file.
with open("cleaned.json", 'w', encoding='utf-8') as target:
    json.dump(new_data, target, indent=4, ensure_ascii=False)

130575


In [32]:
# Use CLIP retrieval to generate the image-to-class-name re-ranking dataset.
def custom_sort(item, match):
    """Sort key that puts entries equal to ``match`` first.

    Returns 0 when ``item == match`` and 1 otherwise.
    """
    return int(not item == match)
    
### Initialisation
index_img_save_path = "./fgvc_database/index_img.index"
index = faiss.read_index(index_img_save_path)

RERANK_TOP_K = 20    # neighbours requested per query image
RERANK_WINDOW = 5    # candidate names per record (the prompt text says "top 5")
# Runtime prompt template; {names} is filled with the candidate list.
rerank_prompt = 'Here is a image:<Img index=1><image></Img>. Please play the role of a aircraft classification expert, and sort the provided categories from high to low according to the top 5 similarity with the input image. Here are the optional categories:{names}.'

# Parse the trainset file once instead of re-reading it for every query image.
# Row r of the search result is looked up as line r of this file. Lines that
# cannot be parsed are stored as None and only reported if actually retrieved.
with open(trainset_file_path, 'r') as file:
    lines = file.readlines()
trainset_rows = []
for line in lines:
    parts = line.strip().split(' ', 1)
    label_match = re.search(r":\s*(\d+)", parts[1]) if len(parts) == 2 else None
    trainset_rows.append((parts[0], int(label_match.group(1))) if label_match else None)

json_data = []
### Iterate over the dataset
for train_data in tqdm(trainset):

    ### Image path, label and class name of the query image
    train_data_label = int(train_data[1])
    train_data_image = train_data[0]
    train_data_name = classnames[train_data_label].strip()

    ### Retrieve the nearest images
    with torch.no_grad():
        # Close the file handle as soon as the image has been preprocessed.
        with Image.open(train_data_image) as raw_image:
            image = preprocess(raw_image).unsqueeze(0).to("cuda")
        image_features = clip_model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    # FAISS works on float32 vectors; make the conversion explicit.
    query_vector = image_features.cpu().numpy().astype(np.float32)
    _distances, index_result = index.search(query_vector, RERANK_TOP_K)

    # Class names of the retrieved images, in retrieval order.
    matched_names = []
    for row in index_result[0]:
        if row < 0:
            # FAISS pads with -1 when fewer than k neighbours exist; indexing
            # with -1 would silently pick the last line of the file.
            continue
        parsed_row = trainset_rows[row]
        if parsed_row is None:
            raise ValueError(
                f"{trainset_file_path}: line {int(row) + 1} is not of the form "
                f"'<image path> <text>: <label>': {lines[row]!r}"
            )
        matched_names.append(classnames[parsed_row[1]].strip())

    ### Slide a window of five names over the retrieved list (stride 1) and
    ### keep only the windows that contain the correct class name.
    for start in range(len(matched_names) - RERANK_WINDOW + 1):
        window = matched_names[start:start + RERANK_WINDOW]
        if train_data_name not in window:
            continue

        ### Expected answer: the correct name(s) first; sorted() is stable, so
        ### the remaining names keep their retrieval order.
        ranked = sorted(window, key=lambda item: custom_sort(item, train_data_name))

        # Both lists are embedded as their Python string representation.
        # NOTE(review): unlike the pair dataset, the image path is stored
        # without any prefix rewriting -- confirm that is intended.
        json_data.append({
            'id': int(len(json_data)),
            'image': [train_data_image],
            'conversations': [
                {'from': 'user', 'value': rerank_prompt.format(names=str(window))},
                {'from': 'assistant', 'value': '{result}'.format(result=str(ranked))},
            ],
        })

print(len(json_data))
with open('./classnames_rerank.json', 'w', encoding='utf-8') as file:
    json.dump(json_data, file, indent=4, ensure_ascii=False)

100%|██████████| 3334/3334 [01:21<00:00, 40.69it/s]


30647
