In [1]:
import os
import json
import base64
import re
from collections import OrderedDict
from typing import List, Dict, Any
from openai import OpenAI


# --- Experiment configuration -------------------------------------------
dataset_name = "CIFAR10"
max_images = 8        # upper bound on images read per class directory
MAX_LEN = 35          # concepts longer than this many characters are discarded
gpt_model = "gpt-4o"  # model name sent with each chat-completion request

# --- Input / output locations (all derived from the settings above) -----
root_dir = f"./data/selected_image/{dataset_name}"
concept_output_path = f"./data/generate_concept/concept/{dataset_name}_concepts_{gpt_model}_init.json"
ordered_output_path = f"./data/generate_concept/concept/{dataset_name}_concepts_{gpt_model}_ordered.json"
kv_output_path = f"./data/generate_concept/concept/{dataset_name}_concepts_{gpt_model}_final.json"
class_list_path = f"./data/classes_name/{dataset_name}_classes.txt"

# Shared OpenAI client; generate_concepts() sends its chat-completion
# requests through it. Both arguments below are empty placeholders and must
# be filled in before the script is run.
client = OpenAI(
    base_url="",# TODO: set to the base URL of your API endpoint
    api_key="",# TODO: set to your API key (avoid committing a real key)
)


def clean_readable_name(class_name: str) -> str:
    """Return a human-readable version of a class directory name.

    A leading ``<digits>.`` prefix is removed, underscores become spaces,
    and surrounding whitespace is trimmed.
    """
    without_prefix = re.sub(r'^\d+\.', '', class_name)
    # Splitting on '_' and re-joining with ' ' swaps every underscore for a space.
    spaced = ' '.join(without_prefix.split('_'))
    return spaced.strip()

def encode_image(image_path: str) -> str:
    """Read the file at ``image_path`` and return its bytes as a base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def build_prompt(readable_name: str, dataset_name: str) -> List[Dict]:
    """Build the text part of the chat message that asks for visual concepts.

    Args:
        readable_name: Class name interpolated into the prompt text.
        dataset_name: Accepted for interface compatibility; the prompt text
            does not use it.

    Returns:
        A one-element list holding a ``{"type": "text", "text": ...}`` part.
        Callers append image parts to this list.
    """
    prompt_text = f"""
        You are shown several sample images from "{readable_name}".
        
        Examine their visual patterns and summarize the key features that commonly define this category in its image domain.
        Please focus on traits that are typically present across most samples, even if not in every image. 
        Describe each feature as a short, concrete visual concept without abstraction or inference. 
        Ensure that each concept is concise (≤30 characters) and output as a plain list without numbering or full sentences.
            
        Output format example:
        -two wings
        -a black cap
        -long gray runways
        -a baseball
        -a pattern of spots
        
        Now generate the feature list:
        """
    text_part = {"type": "text", "text": prompt_text}
    return [text_part]

def _parse_concepts(raw_output: str) -> List[str]:
    """Turn the model's bullet-list reply into a list of clean concept strings."""
    concepts = []
    for line in raw_output.splitlines():
        # Remove only leading list markers (bullets, dashes, "1." / "1)").
        # A plain str.strip() over a digit set would also eat digits that are
        # part of the concept itself, e.g. "4 wheels" -> "wheels".
        text = re.sub(r'^\s*(?:(?:[•\-*]|\d+[.)](?=\s))\s*)*', '', line)
        text = text.rstrip(". \t\r")
        # Drop lines that were nothing but markers, and over-long concepts.
        if text and len(text) <= MAX_LEN:
            concepts.append(text)
    return concepts


def generate_concepts():
    """Query the model for visual concepts of every class and save them as JSON.

    For each sub-directory of ``root_dir`` (one per class), up to
    ``max_images`` images are taken in sorted filename order and sent in two
    batches together with the prompt from ``build_prompt``. The parsed
    concepts of both replies are de-duplicated, sorted, and stored under the
    class's readable name. The resulting mapping is written to
    ``concept_output_path``.

    A failed request is reported and skipped, so a class may end up with
    fewer (or no) concepts instead of aborting the whole run.
    """
    all_concepts = {}
    class_names = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]

    for class_name in sorted(class_names):
        image_dir = os.path.join(root_dir, class_name)
        image_files = sorted(os.listdir(image_dir))[:max_images]

        if not image_files:
            print(f"[!] No images found for {class_name}, skipped")
            continue

        readable_name = clean_readable_name(class_name)
        all_responses = []

        # Two batches; the second one takes the remainder so that an odd image
        # count does not silently drop the last image.
        half = len(image_files) // 2
        batches = (image_files[:half], image_files[half:])

        for split, split_files in enumerate(batches):
            if not split_files:
                # Happens when the class has a single image: never send a
                # request that contains no image at all.
                continue

            message_content = build_prompt(readable_name, dataset_name)
            for image_file in split_files:
                image_path = os.path.join(image_dir, image_file)
                base64_img = encode_image(image_path)
                # NOTE: every file is labelled image/jpeg regardless of its
                # actual format.
                message_content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"}
                })

            try:
                completion = client.chat.completions.create(
                    model=gpt_model,
                    messages=[{"role": "user", "content": message_content}],
                )
                # The reply content may be None; treat that as an empty reply.
                raw_output = completion.choices[0].message.content or ""
            except Exception as e:
                # Best effort: report the failure and continue with the rest.
                print(f"[✗] Failed on {class_name} (split {split+1}): {e}")
            else:
                all_responses.extend(_parse_concepts(raw_output))

        unique_concepts = sorted(set(all_responses))
        all_concepts[readable_name] = unique_concepts
        print(f"[✓] {readable_name}: {unique_concepts}")

    # Create the target directory up front so the results of all the requests
    # above are not lost to a missing folder.
    output_dir = os.path.dirname(concept_output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(concept_output_path, "w", encoding="utf-8") as f:
        json.dump(all_concepts, f, indent=2, ensure_ascii=False)
    print(f"\n Save in {concept_output_path}")

def reorder_by_class_list():
    """Rewrite the concept file so its classes follow the class-list order.

    Reads class names from ``class_list_path`` (one per line, blank lines
    ignored) and the concept mapping from ``concept_output_path``, then
    writes the entries to ``ordered_output_path`` in class-list order.

    Class names are normalized with ``clean_readable_name``, the same helper
    that produced the keys of the concept file, so both sides agree on
    numeric prefixes and underscores. Classes without an entry in the
    concept file are omitted from the output and reported.
    """
    with open(class_list_path, "r", encoding="utf-8") as f:
        ordered_classes = [clean_readable_name(line.strip()) for line in f if line.strip()]
    with open(concept_output_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    ordered_data = OrderedDict()
    missing = []
    for cls in ordered_classes:
        if cls in data:
            ordered_data[cls] = data[cls]
        else:
            missing.append(cls)

    if missing:
        # Later steps number classes by their position in this file, so an
        # omitted class shifts the index of every class after it.
        print(f"[!] No concepts found for {len(missing)} class(es), omitted: {missing}")

    with open(ordered_output_path, "w", encoding="utf-8") as f:
        json.dump(ordered_data, f, indent=2, ensure_ascii=False)
    print(f"[✓] JSON Save in ：{ordered_output_path}")

def convert_to_key_value_lists(input_path: str) -> Dict[str, List[Any]]:
    """Flatten a ``{class: [concept, ...]}`` JSON file into two parallel lists.

    Args:
        input_path: Path of the JSON file to read.

    Returns:
        A dict with ``"concepts"`` (all concepts, in file order) and
        ``"concepts_to_class"`` (for each concept, the zero-based position
        of its class within the file).
    """
    with open(input_path, 'r', encoding='utf-8') as src:
        per_class = json.load(src)

    flat_concepts = []
    class_indices = []
    for class_idx, class_concepts in enumerate(per_class.values()):
        flat_concepts.extend(class_concepts)
        # One index entry per concept keeps both lists the same length.
        class_indices.extend([class_idx] * len(class_concepts))

    return {"concepts": flat_concepts, "concepts_to_class": class_indices}

def save_key_value_format():
    """Flatten the ordered concept file and write the result to ``kv_output_path``."""
    flattened = convert_to_key_value_lists(ordered_output_path)
    with open(kv_output_path, 'w', encoding='utf-8') as out_file:
        json.dump(flattened, out_file, indent=2, ensure_ascii=False)
    print(f"[✓] Save in ：{kv_output_path}")
if __name__ == "__main__":
    # Pipeline stages, run strictly in this order: each stage reads the file
    # written by the one before it.
    for stage in (generate_concepts, reorder_by_class_list, save_key_value_format):
        stage()

[✓] airplane: ['cockpit windows', 'engines under wings', 'horizontal stabilizer', 'horizontal stabilizers', 'landing gear', 'long fuselage', 'rounded nose', 'tail fin', 'turbine engines', 'two wings', 'vertical tail', 'white body']
[✓] automobile: ['body paint', 'door handles', 'exhaust pipe', 'four wheels', 'front grille', 'glass windows', 'grille', 'headlights', 'license plate', 'logo badge', 'metallic body', 'rearview mirrors', 'rectangular shape', 'rubber tires', 'side doors', 'side mirrors', 'tires', 'wheels', 'windshield']
[✓] bird: ['Brownish or gray tones', 'a beak', 'a tail', 'feathered body', 'pointed beak', 'rounded body', 'small feathers', 'two legs', 'two wings']
[✓] cat: ['cat ears', 'four legs', 'furry body', 'furry coat', 'large eyes', 'long tail', 'round ears', 'small nose', 'unique fur patterns', 'vertical pupils', 'whiskers']
[✓] deer: ['antlers', 'brown fur', 'brown fur coat', 'four legs', 'grassy background', 'large branching antlers', 'pointed ears', 'prominent ey