### This notebook uses only the **rendered images** for captioning

In [None]:
# Import necessary libraries
import os
import jsonlines
import json
import openai
import csv
import base64
import tiktoken
import time
from time import time
from openai import OpenAI

# Configure the OpenAI API key from the environment instead of hard-coding it.
# Export it before starting the notebook, for example:
#   export OPENAI_API_KEY='your-api-key-here'
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Fail fast (unset or empty variable) so no API call is attempted without a key.
if openai.api_key in (None, ""):
    raise ValueError("OPENAI_API_KEY environment variable not set. Please set it before running this notebook.")

In [None]:
# Define directories and file paths.
# Alternative inputs for quick tests:
#   step_dir  -> "./dataset/abccad/test_purpose_only/step"   (or "./data/abccad_test/0001_step_500")
#   image_dir -> "./dataset/abccad/test_purpose_only/image"  (or "./data/abccad_test/0001_step_500_images")
_abccad_root = "./dataset/abccad"
step_dir = f"{_abccad_root}/step_under500/abc_0001_step_v00_under500"
image_dir = f"{_abccad_root}/step_under500_image/abc_0001_step_v00_under500_image"
output_csv = "cad_captions.csv"


In [None]:
# Encode image as base64
def encode_image(image_path):
    """Return the raw bytes of *image_path* encoded as a base64 text string.

    The file is read in binary mode. The result is plain ASCII, suitable for a
    ``data:image/jpeg;base64,...`` URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.standard_b64encode(raw_bytes).decode("ascii")

In [None]:
# Test image encoding on a single rendering. Only a short preview is printed:
# dumping the full base64 string (often hundreds of KB) bloats the notebook output.
image_path = "./dataset/abccad/step_under500_image/abc_0004_step_v00_under500_image/00040014_dfb67f543711434fb5751d3f_step_003.jpeg"
base64_image = encode_image(image_path)
print(f"Encoded {image_path}: {len(base64_image)} base64 characters")
print(base64_image[:80] + ("..." if len(base64_image) > 80 else ""))

In [None]:
# Extract DATA section from STEP file
def extract_data_section(step_path):
    """Return the DATA section of an ISO 10303-21 (STEP) file as one string.

    Lines are collected from the first line containing ``DATA;`` up to and
    including the ``ENDSEC;`` line that closes that section. The trailing
    ``END-ISO-10303-21;`` line, and anything after it, is not included.
    Returns ``""`` when the file has no ``DATA;`` line.

    The file is decoded as UTF-8 with undecodable bytes replaced. A stray
    non-ASCII byte in one file therefore cannot abort a whole batch, and
    decoding no longer depends on the platform's default encoding.
    """
    collected = []
    in_data = False
    with open(step_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            if not in_data:
                start = line.find("DATA;")
                if start == -1:
                    continue  # still in the HEADER section
                in_data = True
                collected.append(line)
                # Compacted single-line files: the section may also close on this line.
                if "ENDSEC;" in line[start:]:
                    break
                continue
            collected.append(line)
            if line.strip().startswith("ENDSEC;"):
                break  # end of the DATA section
    return "".join(collected)

In [None]:
# Calculate token count using tiktoken
def calculate_tokens(messages, encoder):
    """Count the tokens in the textual content of chat *messages*.

    When a message's ``content`` is a list of parts, only the parts carrying a
    ``"text"`` key are counted, so image parts are skipped. Any other
    ``content`` value is encoded directly. No per-message chat formatting
    overhead is added.

    Args:
        messages: Chat-completions style message dicts.
        encoder: Object whose ``encode(str)`` returns a sequence of tokens
            (e.g. a tiktoken ``Encoding``).

    Returns:
        The total token count as an int.
    """
    def iter_texts(message):
        content = message["content"]
        if not isinstance(content, list):
            yield content
            return
        for part in content:
            if "text" in part:
                yield part["text"]

    return sum(
        len(encoder.encode(text))
        for message in messages
        for text in iter_texts(message)
    )

In [None]:
# Lazily created OpenAI client shared by get_caption() calls.
_caption_client = None


def get_caption(text_message, image_message, client=None, model="gpt-4o",
                temperature=0.2, max_tokens=300):
    """Send one text + image prompt to the chat completions API and return the reply text.

    Args:
        text_message: ``{"type": "text", ...}`` content part with the instructions.
        image_message: ``{"type": "image_url", ...}`` content part with the image.
        client: Optional ``OpenAI`` client. When omitted, a single client is
            created on the first call and reused afterwards, instead of
            building a new client (and HTTP connection pool) for every image.
            The cached client keeps the credentials it read when it was created.
        model, temperature, max_tokens: Forwarded to ``chat.completions.create``.
            The defaults match the values this notebook has always used.

    Returns:
        The assistant message content. If the API returns no text content
        (``None``), returns ``""``, so callers that ``.strip()`` or
        ``json.loads`` the reply get a parse error rather than an
        ``AttributeError``.
    """
    global _caption_client
    if client is None:
        if _caption_client is None:
            _caption_client = OpenAI()
        client = _caption_client

    response = client.chat.completions.create(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant specialized in CAD analysis."
            },
            {
                "role": "user", "content": [text_message, image_message]
            }
        ]
    )

    content = response.choices[0].message.content
    return content if content is not None else ""

### DO NOT USE THIS PROMPT (superseded by the cell marked "USE THIS PROMPT" below)

In [None]:
def parse_caption_response(caption, label=""):
    """Parse the model's JSON reply into an ``(isDescribable, description)`` pair.

    An opening fence of three backticks (optionally tagged ``json``, any case)
    and a closing fence are removed before parsing. Keys missing from the JSON
    object map to ``"N/A"``. A ``None`` reply, invalid JSON, or JSON that is not
    an object yields ``("Error", "Error")`` and prints the raw reply.

    Args:
        caption: Raw assistant text (may be ``None``).
        label: Identifier (e.g. file name) used in the error message.

    Returns:
        Tuple ``(is_describable, description)``.
    """
    import re

    text = (caption or "").strip()
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE)
    try:
        data = json.loads(text)
    except json.JSONDecodeError as e:
        print(f"JSON parsing error for {label}: {e}\nRaw response:\n{caption}")
        return "Error", "Error"
    if not isinstance(data, dict):
        print(f"Unexpected JSON (not an object) for {label}:\n{caption}")
        return "Error", "Error"
    return data.get("isDescribable", "N/A"), data.get("description", "N/A")


# Process directory of STEP files and save results
def process_directory(step_dir, image_dir, output_csv):
    """Caption every not-yet-processed STEP model under *step_dir* and append rows to *output_csv*.

    *step_dir* is walked one sub-folder deep for ``*.step`` files. The matching
    rendering is looked up as ``<image_dir>/<step file stem>.jpeg``. Only the
    image is sent to the model; the STEP text is not part of this prompt.
    ``model_id`` is the file-name prefix before the first ``_``. Ids already in
    *output_csv* are skipped, so an interrupted run can be resumed.

    Each row holds ``model_id``, ``isDescribable``, ``description`` and
    ``token_count``. ``token_count`` covers only the text parts of the request;
    the image is not counted.
    NOTE(review): files sharing a ``model_id`` prefix produce rows with the same
    id. Confirm this is intended.
    """
    encoder = tiktoken.get_encoding("o200k_base")    # gpt-4o uses o200k_base

    existing_model_ids = set()
    if os.path.exists(output_csv):
        with open(output_csv, mode="r", newline="") as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                existing_model_ids.add(row["model_id"])

    # Count total number of STEP files to process
    total_files = 0
    for subfolder in os.listdir(step_dir):
        subfolder_path = os.path.join(step_dir, subfolder)
        if os.path.isdir(subfolder_path):
            for step_file in os.listdir(subfolder_path):
                if step_file.endswith(".step"):
                    model_id = step_file.split('_')[0]
                    if model_id not in existing_model_ids:
                        total_files += 1

    processed_count = 0

    # Write the header only for a new or empty file. Keying this on
    # `existing_model_ids` would duplicate the header in a file that already
    # has a header but no data rows.
    needs_header = not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0

    with open(output_csv, mode="a", newline="") as csvfile:
        fieldnames = ["model_id", "isDescribable", "description", "token_count"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if needs_header:
            writer.writeheader()

        for subfolder in os.listdir(step_dir):
            subfolder_path = os.path.join(step_dir, subfolder)
            if os.path.isdir(subfolder_path):
                for step_file in os.listdir(subfolder_path):
                    if step_file.endswith(".step"):
                        model_id = step_file.split('_')[0]

                        # Skip processing if the model_id already exists in the output CSV
                        if model_id in existing_model_ids:
                            print(f"Skipping {step_file}, already processed.")
                            continue

                        image_path = os.path.join(image_dir, f"{step_file.split('.')[0]}.jpeg")

                        if not os.path.exists(image_path):
                            print(f"Warning: Missing image for {step_file}")
                            continue

                        base64_image = encode_image(image_path)

                        text_message = {
                            "type": "text",
                            "text": f"""

                            You are provided with a multi-view image of a single CAD model. Each image shows the same object from a different perspective. 

                            Please follow these precise instructions:

                            1. **Determine describability**: Assess whether this CAD model (as a whole) can be described using a short natural language caption. Do not describe individual views or interpret them as separate objects. Even if the model consists of multiple connected or adjacent parts, treat it as one entity.

                            2. **Important note on views**:
                            - The **top row** contains standard orthographic projections (e.g., front, top, side views). These provide reliable geometric clues such as length, thickness, angles, and symmetry. Use them as the primary reference for structure.
                            - The lower rows show perspective views from various angles to supplement spatial understanding.
                            - Do **not** mistakenly infer that multiple views indicate multiple separate objects. They are different renderings of the **same** CAD model.
                            - Carefully observe small but important features such as through holes, indentations, internal cavities, or partial hollowness. These details matter in the description.

                            3. **Perspective reasoning**: 
                            - Infer depth, flatness, holes, or slopes using the combination of orthographic and perspective views. Be cautious not to misinterpret shading or foreshortening.
                            - Do not invent geometric features based on misleading perspective. For example, do not describe a cube as having a "triangular" or "diamond-shaped" face just because of viewing angle.

                            4. **Output requirements**:
                            - If the object has a clear, describable geometric or functional shape (e.g., “a cylindrical rod with a flange”,"a water cup"), set `isDescribable: true` and provide a concise description (1–2 lines max).
                            - If the object appears abstract, highly irregular, or does not admit any concise description, set `isDescribable: false`.

                            5.Common mistakes to avoid:
                            - Misinterpreting perspective views as separate parts or separate objects.
                            - Misclassifying flat rectangular plates or cubes as having diamond, triangular, pyramid, or prism-like shapes.
                            - Ignoring holes or cutouts in discs, rods, or tapered parts.
                            - Incorrectly assuming a part is completely solid or completely hollow without checking views.

                            6. **Return the result strictly as a JSON object**, with no commentary or explanation. For example:
                            ```json
                            {{
                            "isDescribable": true,
                            "description": "A cylindrical rod with a conical tip and a flanged base."
                            }}
                            """
                        }

                        image_message = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }

                        # Mirrors the request built in get_caption(); used only for token counting.
                        messages = [
                            {
                                "role": "system",
                                "content": "You are a helpful assistant specialized in CAD analysis."
                            },
                            {
                                "role": "user", "content": [text_message, image_message]
                            }
                        ]

                        token_count = calculate_tokens(messages, encoder)
                        start_time = time()
                        caption = get_caption(text_message, image_message)
                        end_time = time()

                        print(f"Processed {step_file} in {end_time - start_time:.2f} seconds")

                        is_describable, description = parse_caption_response(caption, step_file)

                        writer.writerow({
                            "model_id": model_id,
                            "isDescribable": is_describable,
                            "description": description,
                            "token_count": token_count
                        })
                        # Flush each row so completed work survives a kernel crash (resume relies on it).
                        csvfile.flush()

                        processed_count += 1

                        # Print progress every 50 processed files
                        if processed_count % 50 == 0 or processed_count == total_files:
                            print(f"Processed: {processed_count}/{total_files}")

### DO NOT USE THIS PROMPT (superseded by the cell marked "USE THIS PROMPT" below)

In [None]:
def parse_caption_response(caption, label=""):
    """Parse the model's JSON reply into an ``(isDescribable, description)`` pair.

    An opening fence of three backticks (optionally tagged ``json``, any case)
    and a closing fence are removed before parsing. Keys missing from the JSON
    object map to ``"N/A"``. A ``None`` reply, invalid JSON, or JSON that is not
    an object yields ``("Error", "Error")`` and prints the raw reply.

    Args:
        caption: Raw assistant text (may be ``None``).
        label: Identifier (e.g. file name) used in the error message.

    Returns:
        Tuple ``(is_describable, description)``.
    """
    import re

    text = (caption or "").strip()
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE)
    try:
        data = json.loads(text)
    except json.JSONDecodeError as e:
        print(f"JSON parsing error for {label}: {e}\nRaw response:\n{caption}")
        return "Error", "Error"
    if not isinstance(data, dict):
        print(f"Unexpected JSON (not an object) for {label}:\n{caption}")
        return "Error", "Error"
    return data.get("isDescribable", "N/A"), data.get("description", "N/A")


def process_directory_from_images(image_dir, output_csv):
    """Caption every not-yet-processed ``*.jpeg`` rendering in *image_dir* and append rows to *output_csv*.

    ``model_id`` is the file-name prefix before the first ``_``. Images whose id
    is already in *output_csv* are skipped, so an interrupted run can be
    resumed. Each row holds ``model_id``, ``isDescribable``, ``description`` and
    ``token_count``. ``token_count`` covers only the text parts of the request;
    the image is not counted.
    NOTE(review): renderings sharing a ``model_id`` prefix produce rows with the
    same id. Confirm this is intended.
    """
    encoder = tiktoken.get_encoding("o200k_base")  # gpt-4o uses o200k_base

    existing_model_ids = set()
    if os.path.exists(output_csv):
        with open(output_csv, mode="r", newline="") as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                existing_model_ids.add(row["model_id"])

    image_files = []
    # Sorted so the processing order (and progress output) is reproducible.
    for f in sorted(os.listdir(image_dir)):
        if not f.endswith(".jpeg"):
            continue  # not a rendering; previously mis-reported as "already processed"
        if f.split("_")[0] in existing_model_ids:
            print(f"Skipping {f}, already processed.")
            continue
        image_files.append(f)
    total_files = len(image_files)
    processed_count = 0

    # Write the header only for a new or empty file. Keying this on
    # `existing_model_ids` would duplicate the header in a file that already
    # has a header but no data rows.
    needs_header = not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0

    with open(output_csv, mode="a", newline="") as csvfile:
        fieldnames = ["model_id", "isDescribable", "description", "token_count"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if needs_header:
            writer.writeheader()

        for image_file in image_files:
            model_id = image_file.split("_")[0]
            image_path = os.path.join(image_dir, image_file)

            if not os.path.exists(image_path):
                print(f"Warning: Missing image {image_file}")
                continue

            base64_image = encode_image(image_path)

            text_message = {
                "type": "text",
                "text": f"""
                You are provided with a multi-view image of a single CAD model. Each image shows the same object from a different perspective. 

                Please follow these precise instructions:

                1. **Determine describability**: Assess whether this CAD model (as a whole) can be described using a short natural language caption. Do not describe individual views or interpret them as separate objects. Even if the model consists of multiple connected or adjacent parts, treat it as one entity.
                2. **Important note on views**:
                - The **top row** contains standard orthographic projections (e.g., front, top, side views). These provide reliable geometric clues such as length, thickness, angles, and symmetry. Use them as the primary reference for structure.
                - The lower rows show perspective views from various angles to supplement spatial understanding.
                - Do **not** mistakenly infer that multiple views indicate multiple separate objects. They are different renderings of the **same** CAD model.
                - Carefully observe small but important features such as through holes, indentations, internal cavities, or partial hollowness. These details matter in the description.

                3. **Perspective reasoning**: 
                - Infer depth, flatness, holes, or slopes using the combination of orthographic and perspective views. Be cautious not to misinterpret shading or foreshortening.
                - Do not invent geometric features based on misleading perspective. For example, do not describe a cube as having a "triangular" or "diamond-shaped" face just because of viewing angle.

                4. **Output requirements**:
                - If the object has a clear, describable geometric or functional shape (e.g., “a cylindrical rod with a flange”,"a water cup"), set `isDescribable: true` and provide a concise description (1–2 lines max).
                - If the object appears abstract, highly irregular, or does not admit any concise description, set `isDescribable: false`.

                5. **Real-world class guessing (encouraged)**:
                - When the object **clearly resembles a real-world item** — such as a cup, plate, wheel, bolt, gear, bracket, or hinge — you are encouraged to **label it accordingly** in the description, as long as the geometry supports that interpretation.
                - Be conservative and only assign class-like names (e.g., “a hinge”, “a pipe connector”) when the overall shape and structure strongly suggest the function or category.
                - If unsure, default to a purely geometric description.

                6. Common mistakes to avoid:
                - Misinterpreting perspective views as separate parts or separate objects.
                - Misclassifying flat rectangular plates, rectangular blocks, or cubes as having diamond, triangular, or pyramid-like shapes.
                - Ignoring holes or cutouts in discs, rods, or tapered parts.
                - Incorrectly assuming a part is completely solid or completely hollow without checking views.
                - Failing to distinguish between the object (colored in brown) and the background (colored in grey), treating a thin ring as a disc or plate.
                - Mistaking uniformly thick plates, rods, or disks for unevenly thick ones.
                
                7. **Return the result strictly as a JSON object**, with no commentary or explanation. For example:
                ```json
                {{
                "isDescribable": true,
                "description": "A cylindrical rod with a conical tip and a flanged base."
                }}
                """
            }

            image_message = {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            }

            # Mirrors the request built in get_caption(); used only for token counting.
            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant specialized in CAD analysis."
                },
                {
                    "role": "user",
                    "content": [text_message, image_message]
                }
            ]

            token_count = calculate_tokens(messages, encoder)
            start_time = time()
            caption = get_caption(text_message, image_message)
            end_time = time()

            print(f"Processed {image_file} in {end_time - start_time:.2f} seconds")

            is_describable, description = parse_caption_response(caption, image_file)

            writer.writerow({
                "model_id": model_id,
                "isDescribable": is_describable,
                "description": description,
                "token_count": token_count
            })
            # Flush each row so completed work survives a kernel crash (resume relies on it).
            csvfile.flush()

            processed_count += 1
            if processed_count % 50 == 0 or processed_count == total_files:
                print(f"Processed: {processed_count}/{total_files}")

### USE THIS PROMPT (this cell redefines `process_directory_from_images`; it is the version the run cell below calls)

In [None]:
def parse_caption_response(caption, label=""):
    """Parse the model's JSON reply into an ``(isDescribable, description)`` pair.

    An opening fence of three backticks (optionally tagged ``json``, any case)
    and a closing fence are removed before parsing. Keys missing from the JSON
    object map to ``"N/A"``. A ``None`` reply, invalid JSON, or JSON that is not
    an object yields ``("Error", "Error")`` and prints the raw reply.

    Args:
        caption: Raw assistant text (may be ``None``).
        label: Identifier (e.g. file name) used in the error message.

    Returns:
        Tuple ``(is_describable, description)``.
    """
    import re

    text = (caption or "").strip()
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE)
    try:
        data = json.loads(text)
    except json.JSONDecodeError as e:
        print(f"JSON parsing error for {label}: {e}\nRaw response:\n{caption}")
        return "Error", "Error"
    if not isinstance(data, dict):
        print(f"Unexpected JSON (not an object) for {label}:\n{caption}")
        return "Error", "Error"
    return data.get("isDescribable", "N/A"), data.get("description", "N/A")


def process_directory_from_images(image_dir, output_csv):
    """Caption every not-yet-processed ``*.jpeg`` rendering in *image_dir* and append rows to *output_csv*.

    ``model_id`` is the file-name prefix before the first ``_``. Images whose id
    is already in *output_csv* are skipped, so an interrupted run can be
    resumed. Each row holds ``model_id``, ``isDescribable``, ``description`` and
    ``token_count``. ``token_count`` covers only the text parts of the request;
    the image is not counted.
    NOTE(review): renderings sharing a ``model_id`` prefix produce rows with the
    same id. Confirm this is intended.
    """
    encoder = tiktoken.get_encoding("o200k_base")  # gpt-4o uses o200k_base

    existing_model_ids = set()
    if os.path.exists(output_csv):
        with open(output_csv, mode="r", newline="") as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                existing_model_ids.add(row["model_id"])

    image_files = []
    # Sorted so the processing order (and progress output) is reproducible.
    for f in sorted(os.listdir(image_dir)):
        if not f.endswith(".jpeg"):
            continue  # not a rendering; previously mis-reported as "already processed"
        if f.split("_")[0] in existing_model_ids:
            print(f"Skipping {f}, already processed.")
            continue
        image_files.append(f)
    total_files = len(image_files)
    processed_count = 0

    # Write the header only for a new or empty file. Keying this on
    # `existing_model_ids` would duplicate the header in a file that already
    # has a header but no data rows.
    needs_header = not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0

    with open(output_csv, mode="a", newline="") as csvfile:
        fieldnames = ["model_id", "isDescribable", "description", "token_count"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if needs_header:
            writer.writeheader()

        for image_file in image_files:
            model_id = image_file.split("_")[0]
            image_path = os.path.join(image_dir, image_file)

            if not os.path.exists(image_path):
                print(f"Warning: Missing image {image_file}")
                continue

            base64_image = encode_image(image_path)

            text_message = {
                "type": "text",
                "text": f"""
                You are provided with a multi-view image of a single CAD model. Each image shows the same object from a different perspective. 

                Please follow these precise instructions:

                1. Determine describability: Assess whether this CAD model (as a whole) can be described using a short natural language caption. Do not describe individual views or interpret them as separate objects. Even if the model consists of multiple connected or adjacent parts, treat it as one entity.
                2. Important note on views:
                - There are 9 views in total, arranged in 3 rows and 3 columns.
                - The top row contains standard orthographic projections (e.g., front, top, side views). These provide reliable geometric clues such as length, thickness, angles, and symmetry. Use them as the primary reference for structure.
                - The lower rows show perspective views from various angles to supplement spatial understanding.
                - Do not mistakenly infer that multiple views indicate multiple separate objects. They are different renderings of the same CAD model.
                - Carefully observe small but important features such as through holes, indentations, internal cavities, or partial hollowness. These details matter in the description.

                3. Perspective reasoning: 
                - Infer depth, flatness, holes, or slopes using the combination of orthographic and perspective views. Be cautious not to misinterpret shading or foreshortening.
                - Do not invent geometric features based on misleading perspective. For example, do not describe a cube as having a "triangular" or "diamond-shaped" face just because of viewing angle.
                - If it can be judged as a cube from the top row images, simply state it "A cube".
                - If it can be judged as a rectangular block from the top row images, simply state it "A rectangular block".
                - Be coutious about using the word "tapered", double check the top row images to see if it is really a tapered part or not.

                4. Output requirements:
                - If the object has a clear, describable geometric or functional shape (e.g., “a cylindrical rod with a flange”,"a water cup"), set `isDescribable: true` and provide a concise description (1-2 lines max).
                - If the object appears abstract, highly irregular, or does not admit any concise description, set `isDescribable: false`.

                5. Real-world class guessing (encouraged):
                - When the object clearly resembles a real-world item — such as a cup, plate, wheel, bolt, gear, bracket, or hinge — you are encouraged to label it accordingly in the description, as long as the geometry supports that interpretation.
                - Assign class-like names (e.g., “a hinge”, “a pipe connector”) when the overall shape and structure strongly suggest the function or category.
                - If unsure, default to a purely geometric description.

                6. Common mistakes to avoid:
                - Misinterpreting perspective views as separate parts or separate objects.
                - Misclassifying flat rectangular plates, rectangular blocks, or cubes as having diamond, triangular, or pyramid-like shapes.
                - Ignoring holes or cutouts in discs, rods, or tapered parts.
                - Incorrectly assuming a part is completely solid or completely hollow without checking views.
                - Failing to distinguish between the object (colored in brown) and the background (colored in grey), treating a thin ring as a disc or plate.
                - Mistaking uniformly thick plates, rods, or disks for unevenly thick ones.
                - Do not describe a cube as "A cube with a triangular prism on top" or "A cube with a pyramid on top of one face."
                - Do not describe a rectangular block as "A rectangular block with a triangular prism on top" or "A rectangular block with a pyramid on top of one face."
                
                7. **Return the result strictly as a JSON object**, with no commentary or explanation. For example:
                ```json
                {{
                "isDescribable": true,
                "description": "A cylindrical rod with a conical tip and a flanged base."
                }}
                """
            }

            image_message = {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            }

            # Mirrors the request built in get_caption(); used only for token counting.
            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful assistant specialized in CAD analysis."
                },
                {
                    "role": "user",
                    "content": [text_message, image_message]
                }
            ]

            token_count = calculate_tokens(messages, encoder)
            start_time = time()
            caption = get_caption(text_message, image_message)
            end_time = time()

            print(f"Processed {image_file} in {end_time - start_time:.2f} seconds")

            is_describable, description = parse_caption_response(caption, image_file)

            writer.writerow({
                "model_id": model_id,
                "isDescribable": is_describable,
                "description": description,
                "token_count": token_count
            })
            # Flush each row so completed work survives a kernel crash (resume relies on it).
            csvfile.flush()

            processed_count += 1
            if processed_count % 50 == 0 or processed_count == total_files:
                print(f"Processed: {processed_count}/{total_files}")

In [None]:
# Caption the rendered images only. To walk the STEP sub-folders instead, call
# process_directory(step_dir, image_dir, output_csv).
process_directory_from_images(image_dir=image_dir, output_csv=output_csv)
print(f"Results saved to {output_csv}")

In [None]:
# Scratch check: caption one model from both its STEP DATA section and its
# multi-view rendering. This cell rebinds the notebook-level `get_caption` to a
# version that reuses the `client` created here.
from openai import OpenAI
client = OpenAI()

step_path = "./dataset/abccad/step_under500/abc_0002_step_v00_under500/00020045/00020045_d23cdf27f9ab48f99dcbdaa1_step_002.step"
step_content = extract_data_section(step_path)

image_path = "./dataset/abccad/step_under500_image/abc_0002_step_v00_under500_image/00020045_d23cdf27f9ab48f99dcbdaa1_step_002.jpeg"
base64_image = encode_image(image_path)

# The STEP text is interpolated into the prompt; doubled braces are literal braces.
text_message = dict(
    type="text",
    text=f"""
    You are provided with the following DATA section of a CAD model in STEP format:
    {step_content}

    You are also provided with a multi-view image that shows the rendered model. 

    Please do the following:
    1. Determine whether this CAD model depicts an object that can be briefly described in natural language (e.g., "A cylindrical rod with a base").
    2. If it can be described, provide a concise sentence (one or two lines) that accurately captures the object's main geometric or functional characteristics.
    3. If it appears completely random, meaningless, or not describable in simple terms, respond with "cannot be described".

    Return your final answer in the format:
        {{
            "isDescribable": true/false,
            "description": "..."
        }}
    """,
)

image_message = dict(
    type="image_url",
    image_url=dict(url=f"data:image/jpeg;base64,{base64_image}"),
)


def get_caption(text_message, image_message):
    """Send one text + image prompt via the module-level `client`; return the reply text."""
    conversation = [
        {"role": "system", "content": "You are a helpful assistant specialized in CAD analysis."},
        {"role": "user", "content": [text_message, image_message]},
    ]
    completion = client.chat.completions.create(
        model="gpt-4o",
        temperature=0.2,
        max_tokens=300,
        messages=conversation,
    )
    return completion.choices[0].message.content


caption = get_caption(text_message, image_message)
print(f"CAPTION: {caption}\n\n")