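# Crafter

Generate minimal-JSON CAD programs from a natural-language description, either directly with an LLM or by splitting the shape into components (LLM-generated, CLIP-retrieved from the dataset, or parameterized cuboid/cylinder prefabs) and merging them.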

In [67]:
import csv
import json
import os
import re

import numpy as np
import torch
from openai import OpenAI
from transformers import CLIPProcessor, CLIPModel

from generate_embeddings import generate_embeddings

In [68]:
# Run CLIP on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

## 1. Helper Functions

### 1.1 Model Calling

In [69]:
def kimi(question, history=None, api_key=None, model_name="moonshot-v1-8k", temperature=0.3):
	"""Send one user turn to Moonshot's Kimi chat API and return the reply.

	Args:
		question: Text of the new user message.
		history: Optional list of prior chat messages (``{"role", "content"}``
			dicts) placed before the new message; it is not modified.
		api_key: Moonshot API key. Defaults to the ``MOONSHOT_API_KEY``
			environment variable so no secret is stored in the notebook.
		model_name: Moonshot model identifier.
		temperature: Sampling temperature.

	Returns:
		``(reply_text, new_history)`` where ``new_history`` is the messages sent
		plus the assistant reply, ready to be passed back in as ``history``.

	Raises:
		RuntimeError: if no key is given and ``MOONSHOT_API_KEY`` is unset.
	"""
	key = api_key or os.environ.get("MOONSHOT_API_KEY")
	if not key:
		raise RuntimeError("Moonshot API key missing: pass api_key or set MOONSHOT_API_KEY")

	client = OpenAI(
		api_key=key,
		base_url="https://api.moonshot.cn/v1",
	)

	messages = list(history or []) + [{"role": "user", "content": question}]

	completion = client.chat.completions.create(
		model=model_name,
		messages=messages,
		temperature=temperature,
	)

	reply = completion.choices[0].message.content
	return reply, messages + [{"role": "assistant", "content": reply}]

In [70]:
def deepseek(question, history=None, model="r1", api_key=None):
	"""Send one user turn to a free DeepSeek model through OpenRouter.

	Args:
		question: Text of the new user message.
		history: Optional list of prior chat messages placed before the new
			message; it is not modified.
		model: ``"r1"`` or ``"v3"``; selects the OpenRouter model id.
		api_key: OpenRouter API key. Defaults to the ``OPENROUTER_API_KEY``
			environment variable so no secret is stored in the notebook.

	Returns:
		``(reply_text, new_history)``, the same contract as ``kimi``.

	Raises:
		ValueError: for an unknown ``model``. Previously a bare string was
			returned here, which broke every caller that unpacks two values.
		RuntimeError: if no key is given and ``OPENROUTER_API_KEY`` is unset.
	"""
	model_names = {
		"v3": "deepseek/deepseek-v3-base:free",
		"r1": "deepseek/deepseek-r1:free",
	}
	if model not in model_names:
		raise ValueError(f"Invalid model name: {model!r} (expected one of {sorted(model_names)})")

	key = api_key or os.environ.get("OPENROUTER_API_KEY")
	if not key:
		raise RuntimeError("OpenRouter API key missing: pass api_key or set OPENROUTER_API_KEY")

	client = OpenAI(
		base_url="https://openrouter.ai/api/v1",
		api_key=key,
	)

	messages = list(history or []) + [{"role": "user", "content": question}]

	completion = client.chat.completions.create(
		model=model_names[model],
		messages=messages,
	)

	reply = completion.choices[0].message.content
	return reply, messages + [{"role": "assistant", "content": reply}]

### 1.2 JSON Utilities

In [71]:
def load_json(filepath, dump=False):
	"""Read a UTF-8 JSON file.

	Args:
		filepath: Path of the JSON file.
		dump: When true, return the data re-serialized as 4-space-indented
			JSON text (used to embed examples in prompts) instead of the
			parsed object.
	"""
	with open(filepath, encoding='utf-8') as handle:
		data = json.load(handle)
	return json.dumps(data, indent=4) if dump else data

In [None]:
def available_path(dirpath='./outputs/', prefix='generate', ext='.json'):
    """Return the first ``<dirpath>/<prefix><N><ext>`` that does not exist yet.

    N counts up from 0. ``dirpath`` may be given with or without a trailing
    separator; the directory itself is not created here.
    """
    idx = 0
    while True:
        candidate = os.path.join(dirpath, f"{prefix}{idx}{ext}")
        if not os.path.exists(candidate):
            return candidate
        idx += 1

In [99]:
def save_json(json_dic, filepath):
    """Write ``json_dic`` to ``filepath`` as 4-space-indented UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII text (e.g. the
    Chinese shape descriptions) is written literally rather than as escape
    sequences, so the file stays human-readable.
    """
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, "w", encoding='utf-8') as file:
        json.dump(json_dic, file, indent=4, ensure_ascii=False)

In [95]:
def json_generate(prompt, model_call=kimi, limit=3, save=True, save_dir='./inter_result/'):
	"""Prompt ``model_call`` until its answer contains a parseable fenced JSON block.

	Args:
		prompt: First user message. Later turns send "继续生成" ("keep
			generating") so a truncated answer can be continued.
		model_call: ``(question, history) -> (reply, history)`` chat function
			such as ``kimi`` or ``deepseek``.
		limit: Maximum number of model calls.
		save: Whether to write the parsed JSON to a fresh file in ``save_dir``.
		save_dir: Directory handed to ``available_path`` when ``save`` is true.

	Returns:
		``(parsed_json, True)`` on success, otherwise ``(last_reply, False)``.
	"""
	# Matches ```json ... ``` as well as a bare ``` fence, with \n or \r\n.
	fence = re.compile(r"```(?:json)?[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE)
	history = []
	transcript = ""
	reply = ""
	for _ in range(limit):
		reply, history = model_call(prompt, history)
		transcript += reply
		# Prefer a block from the latest reply; fall back to the concatenated
		# transcript in case the JSON was split across a "继续生成" turn.
		for block in fence.findall(reply) + fence.findall(transcript):
			try:
				gene_json = json.loads(block)
			except json.JSONDecodeError:
				continue
			print("Succeed")
			if save:
				savepath = available_path(save_dir)
				save_json(gene_json, savepath)
				print(f"save to {savepath}")
			return gene_json, True
		print("Go on...")
		print(reply)
		prompt = "继续生成"
	print("Fail")
	return reply, False

### 1.3 CLIP

In [75]:
def load_captions(captions_file='data.csv'):
    """Read image captions and CAD-file paths from a CSV index.

    Each row is expected to be ``image_path, caption, pth_path[, ...]``; rows
    with fewer than three columns are ignored. A missing file yields two
    empty dicts.

    Returns:
        ``(captions, pth_paths)``: two independent dicts keyed by image path
        (surrounding double quotes stripped), mapping to the caption and to
        the whitespace-stripped CAD file path respectively.
    """
    captions = {}
    pth_paths = {}
    if not os.path.exists(captions_file):
        return captions, pth_paths
    # newline='' is required by the csv module so quoted fields containing
    # line breaks are read back unchanged.
    with open(captions_file, 'r', encoding='utf-8', newline='') as f:
        for row in csv.reader(f):
            if len(row) < 3:
                continue
            image_path = row[0].strip('"')
            captions[image_path] = row[1]
            pth_paths[image_path] = row[2].strip()
    return captions, pth_paths

In [76]:
# Load the CLIP checkpoint used for text-to-image retrieval and move the model to DEVICE.
# `clip_model`, `clip_processor`, `captions_dict` and `pth_paths_dict` are module-level
# globals that `top_images_view` reads when it calls `generate_embeddings`.
model_name = "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
clip_model = CLIPModel.from_pretrained(model_name).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained(model_name)
# Captions and CAD (.pth) paths keyed by image path, read from data.csv (empty if absent).
captions_dict, pth_paths_dict = load_captions()

In [77]:
def top_images_view(keyword, n_images, use_captions):
    """Rank dataset images by CLIP similarity to a text query.

    Uses the module-level ``clip_model``, ``clip_processor``, ``captions_dict``
    and ``pth_paths_dict`` set up in the CLIP cell.

    Args:
        keyword: Query text.
        n_images: Number of top results to return; an int (or int-like
            string) in ``[1, 1000]``.
        use_captions: Whether captions are passed to ``generate_embeddings``
            and reported in the results. (This was previously forced to False
            inside the function, silently ignoring the argument.)

    Returns:
        ``(top_images, error)``. ``top_images`` is a list of up to ``n_images``
        dicts sorted by descending similarity, each with keys ``filename``,
        ``similarity``, ``path``, ``caption`` and ``pth_path`` (the CAD file
        path from the captions CSV). ``error`` is ``None`` on success,
        otherwise a message string, and ``top_images`` is then empty.
    """
    if not keyword:
        return [], "关键词为空"
    if not n_images:
        return [], "无图片数量"
    # Parse the count separately so only a bad count is reported as invalid,
    # not a ValueError raised while computing embeddings.
    try:
        count = int(n_images)
    except (TypeError, ValueError):
        return [], "非有效的正整数"
    if count < 1 or count > 1000:
        return [], "数量不在1-1000"

    try:
        image_paths, _, similarities = generate_embeddings(
            clip_model, clip_processor, keyword, captions_dict, use_captions
        )
        results = [
            {
                "filename": os.path.basename(path),
                "similarity": sim,
                "path": path,
                "caption": captions_dict.get(path, "No caption") if use_captions else "Captions disabled",
                "pth_path": pth_paths_dict.get(path, "No pth path"),
            }
            for path, sim in zip(image_paths, similarities)
        ]
        top_images = sorted(results, key=lambda r: r["similarity"], reverse=True)[:count]
    except Exception as e:
        return [], f"计算相似度失败: {e}"

    return top_images, None

## 2. Generation Methods

### 2.1 Direct Generate

In [78]:
def whole_generate(desire):
	"""Ask the LLM for one complete minimal-JSON CAD program describing ``desire``.

	The prompt embeds ``./examples/simple.json`` as a format example, followed
	by rules for the ``coordinate_system``/``extrusion``/``sketch`` fields.
	Returns whatever ``json_generate`` returns: ``(cad_json, True)`` or
	``(last_reply, False)``.
	"""
	example = load_json('./examples/simple.json', dump=True)
	instructions = (
		f"请按照如上给出的minimal_json格式生成一个{desire}的CAD程序。"
		"严格按照示例，主体包含'coordinate_system'，'extrusion'，以及'sketch'字段。"
		"'sketch'字段中用face描述面，face中用loop描述构成面的线，loop中只支持arc, line, circle三种。"
		"这里的点全部用'Start Point'，'Mid Point'，'End Point'字段描述，line需要start和end，arc需要start, mid end，对应的是B-ref曲线。"
		"circle需要'Center'和'Radius'两个子字段"
	)
	return json_generate(example + instructions)

### 2.2 Component Division & Merging

#### 2.2.1 Divide

In [79]:
def gene_division(desire):
	"""Ask the LLM to split ``desire`` into described parts plus pairwise relations.

	The prompt embeds ``./examples/gene_params.json`` and states the request
	both before and after the field rules. Returns ``json_generate``'s
	``(parts_json, ok)`` pair.
	"""
	example = load_json('./examples/gene_params.json', dump=True)
	request = f"请按照上述的json格式将一个'{desire}'的复杂图形拆分成小组件的组合。"
	rules = "".join((
		"整体结构与各个字段严格按照示例，整体分为name, parts和positional relationship三个部分。",
		"parts中按照part_1/2/3进行编号，每个子结构中有id, description, position等参数",
		"id是唯一标识符，postion长为3，描述该物体相对于坐标原点的偏移量",
		"positional relationship为一个列表，每个元素是一个字典，描述的是各个物体的相对位置关系。",
		"from和to对应的是物体的id，表示这里描述的是to物体相对于from物体的位置。relationship分为三个维度，每个维度通过一个长为2的字符串列表描述",
		"x轴由left/right/align表示相对位置; y轴由front/back/center/none表示; z轴由up/down/align表示. 第二个维度用intersect和seperate表示是否相交",
	))
	return json_generate(example + request + rules + request)

In [80]:
def prefabs_division(desire):
	"""Ask the LLM to split ``desire`` into cuboid/cylinder prefabs plus relations.

	The prompt embeds ``./examples/prefab_params.json`` and explains the
	``type``/``params`` conventions (cuboid: 3 sizes; cylinder: radius and
	height). Returns ``json_generate``'s ``(parts_json, ok)`` pair.
	"""
	example = load_json('./examples/prefab_params.json', dump=True)
	request = f"请按照上述的json格式将一个'{desire}'的复杂图形拆分成小组件的组合。"
	rules = "".join((
		"整体结构与各个字段严格按照示例，整体分为name, parts和positional relationship三个部分。",
		"parts中按照part_1/2/3进行编号，每个子结构中有id, type, description, params, position等参数",
		"id是唯一标识符，type只有cuboid和cylinder两种，前者params长为3，分别描述长宽高；后者params长为2，对应半径与高度。postion长为3，描述该物体相对于坐标原点的偏移量",
		"positional relationship为一个列表，每个元素是一个字典，描述的是各个物体的相对位置关系。",
		"from和to对应的是物体的id，表示这里描述的是to物体相对于from物体的位置。relationship分为三个维度，每个维度通过一个长为2的字符串列表描述",
		"x轴由left/right/center/none表示相对位置; y轴由front/back/center/none表示; z轴由up/down/center/none表示. 第二个维度用intersect和seperate表示是否相交",
	))
	return json_generate(example + request + rules + request)

#### 2.2.2 Get Components

In [81]:
def component_generate(shape):
	"""Ask the LLM for a minimal-JSON CAD program of a single component ``shape``.

	The prompt embeds ``./basics/cuboid.json`` as the format example and
	repeats the request after the field rules. Returns ``json_generate``'s
	``(cad_json, ok)`` pair.
	"""
	example = load_json('./basics/cuboid.json', dump=True)
	request = f"请按照上述minimal_json的格式生成一个'{shape}'的CAD程序。"
	rules = "".join((
		"严格按照示例，包含'coordinate_system'，'extrusion'，以及'sketch'字段。",
		"'sketch'字段中用face描述面，face中用loop描述构成面的线，loop中只支持arc, line, circle三种。",
		"这里的点全部用'Start Point'，'Mid Point'，'End Point'字段描述，line需要start和end，arc需要start, mid end，对应的是B-ref曲线。",
		"circle需要'Center'和'Radius'两个子字段",
	))
	return json_generate(example + request + rules + request)

In [82]:
def component_clip(shape, num=1, caption=False, dataset_root="../../dataset"):
	"""Retrieve an existing CAD program for ``shape`` via CLIP image search.

	Args:
		shape: Text description of the component.
		num: Number of candidates requested from ``top_images_view``; only
			the best match is used for now.
		caption: Forwarded as ``use_captions`` to ``top_images_view``.
		dataset_root: Root of the dataset tree holding ``minimal_json`` files.

	Returns:
		``(cad_json, True)`` for the best match, or ``(message, False)`` when
		the search reports an error or finds nothing. This matches the
		``(payload, ok)`` convention of ``json_generate``; previously an empty
		result raised IndexError.
	"""
	top_images, error = top_images_view(shape, num, caption)
	if error or not top_images:
		return error or f"No image found for '{shape}'", False
	# NOTE(review): assumes the first 8 characters of the image filename are the dataset uid.
	uid = top_images[0]['filename'][:8]
	filepath = f"{dataset_root}/{uid[:4]}/{uid}/minimal_json/{uid}_merged_vlm.json"
	return load_json(filepath), True

In [83]:
def cuboid_set(cuboid, params):
	"""Resize a cuboid template in place.

	Rewrites ``line_1``..``line_4`` of ``sketch.face_1.loop_1`` as an
	axis-aligned ``length x width`` rectangle with one corner at the sketch
	origin, sets the extrusion depth to ``height``, and refreshes the
	``description`` block.

	Args:
		cuboid: Template dict (as loaded from ``./basics/cuboid.json``) that
			already has the four lines, ``extrusion`` and ``description``.
		params: ``[length, width, height]``; each item is coerced with ``float``.
	"""
	length, width, height = (float(params[i]) for i in range(3))

	# Rectangle corners in drawing order; line_k runs corner k-1 -> corner k.
	corners = [(0.0, 0.0), (length, 0.0), (length, width), (0.0, width)]
	loop = cuboid['sketch']['face_1']['loop_1']
	for idx in range(4):
		edge = loop[f'line_{idx + 1}']
		edge['Start Point'] = list(corners[idx])
		edge['End Point'] = list(corners[(idx + 1) % 4])

	cuboid['extrusion']['extrude_depth_towards_normal'] = height

	summary = cuboid['description']
	summary['length'] = length
	summary['width'] = width
	summary['height'] = height
	summary['shape'] = f"A rectangular prism with length {length}, width {width} and height {height}"

In [84]:
def cylinder_set(cylinder, params):
	"""Resize a cylinder template in place.

	The base circle is centred at ``(radius, radius)`` so it spans
	``[0, 2 * radius]`` on both sketch axes, like the cuboid sketch anchored
	at the origin. The ``description`` therefore records the bounding box
	(``length == width == diameter``). ``get_shape`` reads these values for
	relative placement; previously the radius was stored there, halving the
	size used for alignment.

	Args:
		cylinder: Template dict (as loaded from ``./basics/cylinder.json``)
			containing ``sketch.face_1.loop_1.circle_1``, ``extrusion`` and
			``description``; modified in place.
		params: ``[radius, height]``; each item is coerced with ``float``.
	"""
	radius = float(params[0])
	h = float(params[1])
	diameter = 2 * radius

	circle = cylinder['sketch']['face_1']['loop_1']['circle_1']
	circle['Center'] = [radius, radius]
	circle['Radius'] = radius

	cylinder['extrusion']['extrude_depth_towards_normal'] = h

	des = cylinder['description']
	# Full extents along x/y, as consumed by get_shape() in the merge step.
	des['length'] = diameter
	des['width'] = diameter
	des['height'] = h
	des['shape'] = f'A circular cylinder with diameter {diameter} and height {h}'

In [85]:
def component_prefab(type, params, basics_dir="./basics"):
	"""Load a prefab template and set its dimensions.

	Args:
		type: ``"cuboid"`` or ``"cylinder"``; selects ``<basics_dir>/<type>.json``.
			(The name shadows the builtin but is kept for caller compatibility.)
		params: Cuboid ``[length, width, height]`` or cylinder ``[radius, height]``.
		basics_dir: Directory holding the template JSON files.

	Returns:
		The parameterized CAD dict.

	Raises:
		ValueError: for any other ``type``. Previously such types silently
			returned an unparameterized template or failed with a file error.
	"""
	if type not in ("cuboid", "cylinder"):
		raise ValueError(f"Unsupported prefab type: {type!r} (expected 'cuboid' or 'cylinder')")
	cad_json = load_json(f"{basics_dir}/{type}.json")
	if type == "cuboid":
		cuboid_set(cad_json, params)
	else:
		cylinder_set(cad_json, params)
	return cad_json

#### 2.2.3 Merge

In [93]:
def get_shape(obj_dic):
    """Return a part's ``[length, width, height]`` from its ``description`` dict."""
    description = obj_dic['description']
    return [description[key] for key in ('length', 'width', 'height')]

def axis_compute(base_val, ex_val, base_size, ex_size, relative):
	"""Return the extend part's new coordinate along one axis.

	Coordinates are treated as each part's lower bound on the axis and sizes
	as its extent. ``relative`` is a ``[placement, contact]`` pair:

	* ``left``/``front``/``up``: the extend ends at ``base_val``; with
	  ``intersect`` it straddles that face by half its size.
	* ``right``/``back``/``down``: the extend starts at ``base_val + base_size``;
	  with ``intersect`` it straddles that face by half its size.
	* ``center`` or ``align`` (the word the ``gene_division`` prompt uses):
	  centre the extend on the base.
	* anything else (e.g. ``none``), or an empty spec: keep ``ex_val``.

	A missing contact entry counts as non-intersecting, and matching is
	case-insensitive.

	NOTE(review): ``up`` maps to the negative direction and ``down`` to the
	positive one; confirm this matches the CAD frame's z axis.
	"""
	if not relative:
		return ex_val
	placement = str(relative[0]).lower()
	intersect = len(relative) > 1 and str(relative[1]).lower() == "intersect"
	if placement in ('left', 'front', 'up'):
		return base_val - (ex_size / 2 if intersect else ex_size)
	if placement in ('right', 'back', 'down'):
		return base_val + base_size - (ex_size / 2 if intersect else 0)
	if placement in ('center', 'align'):
		return base_val + base_size / 2 - ex_size / 2
	return ex_val

def relative_compute(base_pos, ex_pos, base_shape, ex_shape, relative):
	"""Move ``ex_pos`` in place so the extend part sits relative to the base part.

	Args:
		base_pos: Base part's ``[x, y, z]`` translation vector.
		ex_pos: Extend part's ``[x, y, z]`` translation vector; mutated.
		base_shape: Base ``[length, width, height]`` (extents along x/y/z).
		ex_shape: Extend ``[length, width, height]``.
		relative: Dict with optional ``'x'``, ``'y'``, ``'z'`` entries, each a
			``[placement, contact]`` pair passed to ``axis_compute``. Axes
			that are absent are left unchanged; previously they raised KeyError.
	"""
	for axis, key in enumerate(('x', 'y', 'z')):
		if key in relative:
			ex_pos[axis] = axis_compute(
				base_pos[axis], ex_pos[axis], base_shape[axis], ex_shape[axis], relative[key]
			)

def relative_set(final_dic, rela_lis):
	"""Reposition parts of ``final_dic`` in place from pairwise relations.

	Each relation is a dict with ``from`` (anchor part number), ``to`` (a part
	number or a list of them) and ``relationship`` (per-axis specs for
	``relative_compute``). Part ``k`` is looked up as ``parts[f"part_{k}"]``.
	Relations are applied in order, so a part placed earlier can anchor a
	later relation. Relations naming unknown parts are skipped with a
	printed notice.
	"""
	parts = final_dic['parts']
	for rela in rela_lis:
		base = f"part_{rela['from']}"
		targets = rela['to']
		# Accept a scalar 'to'; iterating a scalar failed (int) or split it into characters (str).
		extend_ids = targets if isinstance(targets, (list, tuple)) else [targets]
		for extend_id in extend_ids:
			extend = f"part_{extend_id}"
			if base not in parts or extend not in parts:
				print(f"skip relation {base} -> {extend}: unknown part")
				continue
			base_dic = parts[base]
			extend_dic = parts[extend]
			relative_compute(
				base_dic['coordinate_system']['Translation Vector'],
				extend_dic['coordinate_system']['Translation Vector'],
				get_shape(base_dic),
				get_shape(extend_dic),
				rela['relationship'],
			)

In [87]:
def merge_gene(parts_dic, merge):
	"""Generate every part with the LLM and assemble them into one CAD dict.

	Args:
		parts_dic: Division result (from ``gene_division``) with ``name``,
			``parts`` (each having ``id``, ``description``, ``position``) and,
			for ``merge="hand"``, ``positional relationship``. Not modified.
		merge: ``"params"`` places each part at its predicted ``position``;
			``"hand"`` places parts from the relationships via ``relative_set``.

	Returns:
		``{"final_shape": name, "parts": {"part_1": cad, ...}}``. Parts whose
		generation fails are skipped. Relations are remapped from LLM ids to
		the new ``part_<n>`` numbering, so a skipped part no longer shifts
		every later relation onto the wrong part.
	"""
	final_dic = {"final_shape": parts_dic['name'], "parts": {}}
	final_parts = final_dic['parts']
	# str(LLM part id) -> n in the "part_<n>" keys of final_parts.
	id_to_index = {}
	for part_id, part in parts_dic['parts'].items():
		cad_dic, ok = component_generate(part['description'])
		if not ok:
			continue
		if merge == "params":
			cad_dic.setdefault('coordinate_system', {})['Translation Vector'] = part['position']
		index = len(final_parts) + 1
		final_parts[f'part_{index}'] = cad_dic
		id_to_index[str(part.get('id', part_id))] = index
	if merge == "hand":
		relations = []
		for item in parts_dic['positional relationship']:
			targets = item['to'] if isinstance(item['to'], list) else [item['to']]
			base = id_to_index.get(str(item['from']))
			if base is None:
				print(f"skip relation from unknown/failed part: {item['from']}")
				continue
			relations.append({
				**item,
				'from': base,
				'to': [id_to_index[str(t)] for t in targets if str(t) in id_to_index],
			})
		relative_set(final_dic, relations)
	return final_dic

In [88]:
def merge_clip(parts_dic, merge, dataset_root="../../dataset"):
	"""Assemble a shape from dataset CAD programs retrieved with CLIP.

	For each part, the best CLIP match for its ``description`` is loaded and
	every sub-part of that match is copied into the result.

	Args:
		parts_dic: Division result (from ``gene_division``) with ``name``,
			``parts`` (each having ``id``, ``description``, ``position``) and,
			for ``merge="hand"``, ``positional relationship``. Not modified.
		merge: ``"params"`` offsets each retrieved sub-part by the part's
			``position``; ``"hand"`` places parts from the relationships.
		dataset_root: Root of the dataset tree holding ``minimal_json`` files.

	Returns:
		``{"final_shape": name, "parts": {"part_1": cad, ...}}``. Parts with no
		CLIP match are skipped with a printed notice.
	"""
	final_dic = {"final_shape": parts_dic['name'], "parts": {}}
	final_parts = final_dic['parts']
	# str(part id) -> numbers n of its sub-parts in final_parts ("part_<n>").
	sub_map = {}
	for part in parts_dic['parts'].values():
		top_images, error = top_images_view(part['description'], 1, False)
		if error or not top_images:
			print(f"No CLIP match for '{part['description']}': {error}")
			continue
		# NOTE(review): assumes the first 8 characters of the image filename are the dataset uid.
		uid = top_images[0]['filename'][:8]
		filepath = f"{dataset_root}/{uid[:4]}/{uid}/minimal_json/{uid}_merged_vlm.json"
		print(part['description'], filepath)
		part_dic = load_json(filepath)
		indices = sub_map.setdefault(str(part['id']), [])
		for sub_part in part_dic['parts'].values():
			if merge == "params":
				vec = sub_part['coordinate_system']['Translation Vector']
				# Element-wise offset: `vec += position` concatenated the two lists.
				vec[:] = [v + p for v, p in zip(vec, part['position'])]
			index = len(final_parts) + 1
			final_parts[f'part_{index}'] = sub_part
			indices.append(index)
	if merge == "hand":
		relations = []
		for item in parts_dic['positional relationship']:
			targets = item['to'] if isinstance(item['to'], list) else [item['to']]
			base = sub_map.get(str(item['from']))
			if not base:
				print(f"skip relation from unknown/unmatched part: {item['from']}")
				continue
			# Anchor on the first sub-part of 'from'; move every sub-part of each target.
			relations.append({
				**item,
				'from': base[0],
				'to': [i for t in targets for i in sub_map.get(str(t), [])],
			})
		relative_set(final_dic, relations)
	return final_dic

In [89]:
def merge_prefabs(parts_dic, merge="params"):
	"""Build every part from a cuboid/cylinder prefab and assemble them.

	Args:
		parts_dic: Division result (from ``prefabs_division``) with ``name``,
			``parts`` (each having ``id``, ``type``, ``params``, ``position``)
			and, for ``merge="hand"``, ``positional relationship``. Not modified.
		merge: ``"params"`` places each part at its ``position``; ``"hand"``
			places parts from the relationships via ``relative_set``.

	Returns:
		``{"final_shape": name, "parts": {"part_1": cad, ...}}`` with parts
		numbered in input order, like ``merge_gene``/``merge_clip``. Relations
		are remapped from LLM ids to that numbering, so ids that are not
		``1..n`` still resolve.
	"""
	final_dic = {"final_shape": parts_dic['name'], "parts": {}}
	final_parts = final_dic['parts']
	# str(LLM part id) -> n in the "part_<n>" keys of final_parts.
	id_to_index = {}
	for index, (part_id, part) in enumerate(parts_dic['parts'].items(), start=1):
		part_dic = component_prefab(part['type'], part['params'])
		if merge == "params":
			part_dic['coordinate_system']['Translation Vector'] = part['position']
		final_parts[f'part_{index}'] = part_dic
		id_to_index[str(part.get('id', part_id))] = index
	if merge == "hand":
		relations = []
		for item in parts_dic['positional relationship']:
			targets = item['to'] if isinstance(item['to'], list) else [item['to']]
			base = id_to_index.get(str(item['from']))
			if base is None:
				print(f"skip relation from unknown part: {item['from']}")
				continue
			relations.append({
				**item,
				'from': base,
				'to': [id_to_index[str(t)] for t in targets if str(t) in id_to_index],
			})
		relative_set(final_dic, relations)
	return final_dic

#### 2.2.4 Whole Process

In [None]:
def merge_generate(desire, div="clip", merge="params"):
	"""Full component pipeline: divide ``desire`` into parts, build them, and merge.

	Args:
		desire: Natural-language description of the target shape.
		div: ``"gene"`` (LLM-generated parts), ``"clip"`` (CLIP-retrieved
			parts) or ``"prefabs"`` (parameterized cuboid/cylinder parts).
		merge: ``"params"`` (absolute positions) or ``"hand"`` (pairwise
			relationships).

	Returns:
		``(final_dic, True)`` on success, or ``(message_or_reply, False)`` when
		an argument is invalid or the division step fails. Invalid arguments
		previously only printed a warning and then ran the wrong method.
	"""
	if div not in ("gene", "clip", "prefabs"):
		message = f"Invalid division method: {div}"
		print(message)
		return message, False
	if merge not in ("params", "hand"):
		message = f"Invalid merge method: {merge}"
		print(message)
		return message, False

	if div == "prefabs":
		parts_dic, ok = prefabs_division(desire)
	else:
		parts_dic, ok = gene_division(desire)
	if not ok:
		return parts_dic, False

	if div == "prefabs":
		final_dic = merge_prefabs(parts_dic, merge)
	elif div == "clip":
		final_dic = merge_clip(parts_dic, merge)
	else:
		final_dic = merge_gene(parts_dic, merge)

	final_dic['final_name'] = desire
	return final_dic, True

## 3. Interactive Run

In [98]:
# Decompose the target into cuboid/cylinder prefabs, place them from the LLM's
# pairwise relationships, and save the assembled CAD JSON under ./outputs/.
desire = "一根竖直的杆子，由五个同轴不同半径的圆柱体组成，圆柱面上下对齐拼接但相互不相交"
final_dic, res = merge_generate(desire, div="prefabs", merge="hand")
if res:
    filepath = available_path()
    save_json(final_dic, filepath)
    print(f"final result saved to {filepath}")
else:
    # On failure final_dic holds the last model reply or an error message.
    print(f"generation failed: {final_dic}")

Succeed
save to ./inter_result/generate2.json
