In [1]:
from datasets import load_dataset
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image
import torch

In [2]:
# Fetch the train and test splits of the GenCAD-Code dataset.
# A list is passed as `split`, and the result is unpacked below into one
# dataset per requested split, in the same order.
# NOTE(review): the cache path is relative ("./Volumes/..."); if an absolute
# mount point such as "/Volumes/..." was intended, the leading "." is a typo
# — confirm before relying on the cache location.
load_options = {
    "num_proc": 16,
    "split": ["train", "test"],
    "cache_dir": "./Volumes/BIG-DATA/HUGGINGFACE_CACHE",
}
ds = load_dataset("CADCODER/GenCAD-Code", **load_options)

(train_dataset, test_dataset) = ds

In [3]:
# Display the test split's summary (feature names and row count).
test_dataset

Dataset({
    features: ['image', 'deepcad_id', 'cadquery', 'token_count', 'prompt', 'hundred_subset'],
    num_rows: 7355
})

In [4]:
# Load the image-to-CAD-code model plus the image processor and tokenizer
# that prepare its inputs and decode its outputs.
# Choose the device at runtime: use CUDA when available, otherwise fall back
# to CPU so this cell does not fail on machines without a GPU.
model_device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained(
    "Thehunter99/vit-codegpt-cadcoder",
    device_map=model_device,
)
# NOTE(review): the processor and tokenizer are loaded from different
# repositories than the model itself — confirm they match the ones the model
# was trained with, since a mismatch would corrupt the decoded text.
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
tokenizer = AutoTokenizer.from_pretrained("microsoft/CodeGPT-small-py")

In [7]:
# Run the model on the first training sample and print the decoded CAD code.
image = train_dataset[0]["image"]
# Normalise PIL images in other modes (e.g. RGBA or grayscale) to RGB before
# preprocessing; presumably the ViT processor expects 3-channel input —
# TODO confirm against the processor config.
if isinstance(image, Image.Image) and image.mode != "RGB":
    image = image.convert("RGB")
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

# Generate CAD code with greedy decoding (do_sample=False, no beam search).
# NOTE(review): the greedy output for this sample is highly repetitive; beam
# search (num_beams=4, early_stopping=True) was previously tried here and may
# be worth re-enabling — verify against expected model usage.
with torch.no_grad():
    generated_ids = model.generate(
        # Move inputs to wherever the model lives instead of assuming CUDA.
        pixel_values.to(model.device),
        max_length=1024,
        do_sample=False,
        # Reuse the EOS token id as the padding id for generation.
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode only the first (and only) generated sequence, dropping special tokens.
generated_code = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_code)


 cad as
 Generating work for 0wpsketch =.(.,.,.,.)solid=00125 cqWork= cqWork(.(.,.,.,.),.,.),.(.,.,.)solid=00125
0wpsketch=_0001 cqWork=_0001 cqWork0=_001 cqWork000 cqWork000 cqWork000 cqWork000 cqWork000 cqWork000 cqWork000 cqWork0000 cqWork0000 cqWork0000 cqWork0000 0000 0000 000 000 000 00close
=_000125 00close
=_00125
 Generating work for 1wpsketch =.(.56,.).(.56)solid=00125 cqWork=_001(.,.).(.56)solid=00125
0wpsketch=_001 cqWork=_001(.,.).(.56)solid=00125 cqWork=_001(.,.).(.56)solid=00125 cqWork=_001(.,.).(.56)solid=00125 cqWork=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=00125 cqWork=_001(.,.).(.56)solid=00125 cqWork=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=00125
0wpsketch=_001(.,.).(.56)solid=0