In [None]:
!pip install fsspec==2023.9.2 --quiet

In [None]:
# Authenticate with the Hugging Face Hub. `<token-here>` is a placeholder that
# must be replaced before running; do not commit a real token in the notebook.
# NOTE(review): presumably required because training is configured with
# push_to_hub=True — confirm whether the dataset also needs authentication.
!huggingface-cli login --token <token-here>

In [None]:
from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor, ViTModel
from transformers import RoFormerForCausalLM, RoFormerConfig
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import torch

In [None]:
# Hub id of the tokenizer whose vocabulary the decoder is built on.
tokenizer_path = "openai-community/gpt2"

In [None]:
# Checkpoint id shared by the ViT encoder and its image processor.
vision_encoder_id = "google/vit-base-patch16-224-in21k"

# Pretrained ViT backbone; it becomes the encoder of the encoder-decoder model.
vision_encoder = ViTModel.from_pretrained(vision_encoder_id)

In [None]:
# Load the tokenizer used to encode the CadQuery scripts.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Reuse EOS as the padding token.
# NOTE(review): presumably because this tokenizer ships without a pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token

# Bare expression: displayed as the cell output to show which token now doubles as padding.
tokenizer.eos_token

In [None]:
# Decoder: a RoFormer causal language model built from this config (not loaded
# from a checkpoint), with its vocabulary sized to the tokenizer.
decoder_config = RoFormerConfig(
    vocab_size=len(tokenizer),
    hidden_size=512,
    num_hidden_layers=12,
    # `num_attention_layers` was removed here: it is not a RoFormerConfig option,
    # so it was only stored as an unused extra attribute and never affected the
    # model. Depth is controlled by `num_hidden_layers` alone.
    num_attention_heads=4,
    intermediate_size=1024,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.3,
    max_position_embeddings=2048,
    type_vocab_size=1,
    # Padding reuses the EOS id, matching `tokenizer.pad_token = tokenizer.eos_token`.
    pad_token_id=tokenizer.eos_token_id,
    # Causal self-attention plus cross-attention over the encoder's outputs.
    is_decoder=True,
    add_cross_attention=True
)

decoder = RoFormerForCausalLM(decoder_config)

In [None]:
# Image preprocessing loaded from the same checkpoint id as the ViT encoder.
image_processor = ViTImageProcessor.from_pretrained(vision_encoder_id)

In [None]:
# Combine the ViT encoder and the RoFormer decoder into one image-to-text model.
model = VisionEncoderDecoderModel(decoder=decoder, encoder=vision_encoder)

# Special-token ids on the combined config: the pad token (which is also EOS)
# serves both as padding and as the first token fed to the decoder.
model.config.decoder_start_token_id = tokenizer.pad_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Make sure the nested decoder config is flagged as a cross-attending decoder.
model.config.decoder.add_cross_attention = True
model.config.decoder.is_decoder = True

In [None]:
# Display the model repr to inspect the assembled architecture.
model

In [None]:
# Load all splits of the dataset and drop the columns this notebook never reads.
ds = (
    load_dataset("CADCODER/GenCAD-Code", num_proc=4)
    .remove_columns(["token_count", "deepcad_id", "hundred_subset", "prompt"])
)

# Show the resulting splits and their remaining columns.
ds

In [None]:
def filter_queries_by_length(example, max_tokens=1022):
    """Return True when the example's CadQuery script fits the token budget.

    Used as a ``Dataset.filter`` predicate so the proof-of-concept subset only
    contains scripts short enough to train on.

    Args:
        example: A dataset row; only its ``'cadquery'`` text field is read.
        max_tokens: Largest token count (inclusive) that is still accepted.

    Returns:
        bool: ``True`` if the tokenized script has at most ``max_tokens`` tokens.
    """
    # Read the ids by field name. Integer indexing (`tokens[0]`) only works on
    # the output of fast tokenizers, so it would break with a slow tokenizer.
    token_ids = tokenizer(example['cadquery'])['input_ids']
    return len(token_ids) <= max_tokens

In [None]:
# Proof-of-concept training subset: the first 4096 rows, minus over-length scripts.
train_data = (
    ds['train']
    .select(range(4096))
    .filter(filter_queries_by_length)
)

In [None]:
# Validation subset: the first 128 rows, minus over-length scripts.
valid_data = (
    ds['validation']
    .select(range(128))
    .filter(filter_queries_by_length)
)

In [None]:
# Held-out test subset: the first 1024 rows, minus over-length scripts.
test_data = (
    ds['test']
    .select(range(1024))
    .filter(filter_queries_by_length)
)

In [None]:
# Display the filtered train and validation subsets.
train_data, valid_data

In [None]:
training_args = TrainingArguments(
    # --- Output and Hub upload ---
    output_dir="./results",
    push_to_hub=True,
    hub_model_id="khairi/SmolLM-Vit-CAD",
    report_to="none",

    # --- Batching: 2 samples x 4 accumulation steps = 8 per optimizer step per device ---
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    fp16=True,

    # --- Optimisation schedule ---
    num_train_epochs=3,
    optim="adamw_torch",
    learning_rate=3e-5,
    weight_decay=0.0001,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,

    # --- Evaluation, checkpointing and logging, all every 100 steps ---
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    logging_steps=100,
    save_total_limit=3,
    load_best_model_at_end=True,

    # Keep the raw 'image' / 'cadquery' columns so the custom collator can read them.
    remove_unused_columns=False,
)

In [None]:

def collate_fn(batch):
    """Collate dataset rows into one training batch for the encoder-decoder model.

    Args:
        batch: List of dataset rows, each providing an ``'image'`` and a
            ``'cadquery'`` text field.

    Returns:
        dict: The image processor's outputs plus ``labels`` (padding positions
        set to -100) and ``decoder_attention_mask``.
    """
    images = [row['image'] for row in batch]
    scripts = [row['cadquery'] for row in batch]

    image_features = image_processor.preprocess(images, return_tensors="pt")

    # Pad only to the longest script in the batch instead of always to 2048
    # tokens. Truncating at 2048 keeps sequences within the decoder's
    # `max_position_embeddings`.
    encoded = tokenizer(
        scripts,
        return_tensors="pt",
        padding='longest',
        truncation=True,
        max_length=2048,
    )
    attention_mask = encoded['attention_mask']

    # Exclude padding from the loss. Mask by attention mask, not by token id:
    # the pad token is also EOS, so an id-based mask could hide real EOS tokens.
    labels = encoded['input_ids'].masked_fill(attention_mask == 0, -100)

    # `decoder_input_ids` is deliberately NOT supplied. VisionEncoderDecoderModel
    # derives it by shifting `labels` one position to the right, and computes the
    # loss without any further shift. Passing the unshifted ids (as before) made
    # the input at each position equal to its own target, so the decoder only
    # learned to copy its input.
    # The unshifted mask still lines up with the shifted inputs for every
    # position that carries a real label (assumes right padding — the
    # tokenizer's default; TODO confirm if padding_side is ever changed).
    # NOTE(review): the scripts do not appear to be terminated with EOS, so the
    # model may never learn when to stop generating — consider appending
    # `tokenizer.eos_token` to each script.
    return {
        **image_features,
        'labels': labels,
        'decoder_attention_mask': attention_mask,
    }

In [None]:
# Wire the model, data and custom collator into the Hugging Face training loop.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_data,
    eval_dataset=valid_data,
    data_collator=collate_fn,
)

In [None]:
# Start training with the Trainer built from the model, datasets and collator.
trainer.train()