In [1]:
from transformers import AutoProcessor, GroundingDinoForObjectDetection,GroundingDinoConfig
import torch
from PIL import Image

from transformers import Trainer, TrainingArguments

from Huggingface_agent.finetune.loss_utils import loss_helper
from Huggingface_agent.finetune.finetune_utils import get_finetune_model, get_dataset

2024-12-24 10:21:43.097091: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else "cpu"

# Local directory that holds the pretrained checkpoint, its config and the
# processor files; defined once so every loader below reads the same place.
MODEL_DIR = './model'

processor = AutoProcessor.from_pretrained(MODEL_DIR)

# Start from the configuration saved with the checkpoint and override only the
# two loss weights. Building a bare GroundingDinoConfig(...) instead would
# replace every other setting stored in the checkpoint's config.json with the
# library defaults.
# NOTE(review): the Trainer in this notebook is given a custom
# compute_loss_func; confirm these coefficients are actually consumed somewhere.
config = GroundingDinoConfig.from_pretrained(
    MODEL_DIR,
    bbox_loss_coefficient=10.0,
    giou_loss_coefficient=10.0,
)
model = GroundingDinoForObjectDetection.from_pretrained(MODEL_DIR, config=config).to(device)

# Project helpers: wrap the base model for fine-tuning and build the training set.
finetune_model = get_finetune_model(model, device)
train_dataset = get_dataset()

def data_collator(batch:list[dict]):
    """Collate dataset records into one processor batch with a nested labels dict.

    Each record in ``batch`` is read for three keys: ``'image_path'`` (opened
    relative to the current directory and converted to RGB), ``'range'`` and
    ``'icon_num'`` (both passed through untouched into the labels).

    Returns the processor output, moved to ``device``, with an extra
    ``'labels'`` entry holding ``input_ids``, ``target_sizes``, ``bbox`` and
    ``icon_num``.
    """
    rel_paths = [record['image_path'] for record in batch]

    rgb_images = []
    for rel_path in rel_paths:
        rgb_images.append(Image.open(f"./{rel_path}").convert("RGB"))

    # Every sample is paired with the same fixed text prompt.
    prompts = ['an interface. an icon.'] * len(rel_paths)

    encoded = processor(images=rgb_images, text=prompts, return_tensors='pt')
    encoded = encoded.to(device)

    # PIL reports size as (width, height); store it reversed, as (height, width).
    sizes_hw = [(height, width) for width, height in (img.size for img in rgb_images)]

    encoded['labels'] = {
        'input_ids': encoded['input_ids'],
        'target_sizes': sizes_hw,
        'bbox': [record['range'] for record in batch],
        'icon_num': [record['icon_num'] for record in batch],
    }
    return encoded

def loss(outputs, labels, **kwargs):
    """Post-process raw model outputs into detections and score them with loss_helper.

    ``labels`` is a dict; this function reads its ``'input_ids'`` and
    ``'target_sizes'`` entries and forwards the whole dict to ``loss_helper``.
    Extra keyword arguments are accepted and ignored.

    NOTE(review): the collator stores each size as the reversed PIL size and
    this function reverses it again before post-processing. The Hugging Face
    docs describe ``target_sizes`` as (height, width) -- confirm the order
    passed here is what ``loss_helper`` expects.
    """
    token_ids = labels['input_ids']
    reversed_sizes = [size[::-1] for size in labels['target_sizes']]

    # Both thresholds are zero so no prediction is filtered out during training.
    detections = processor.post_process_grounded_object_detection(
        outputs,
        token_ids,
        box_threshold=0.,
        text_threshold=0.,
        target_sizes=reversed_sizes,
    )
    return loss_helper(detections, labels, device)

# Training configuration handed to the Trainer.
train_args = TrainingArguments(
    # Mode: train only, no evaluation pass.
    do_train=True,
    do_eval=False,
    # Schedule.
    num_train_epochs=20,
    per_device_train_batch_size=1,
    # NOTE(review): 0.5 is far above the learning rates normally used for
    # fine-tuning -- confirm this value is intentional.
    learning_rate=5e-1,
    # Checkpointing and logging.
    output_dir='./ckpt',
    save_steps=50,
    logging_steps=20,
    report_to="tensorboard",
    # Memory / data loading.
    torch_empty_cache_steps=500,
    # Keep every dataset column: the collator reads its own keys from each record.
    remove_unused_columns=False,
    # The collator already moves tensors to `device`, so pinning is turned off.
    dataloader_pin_memory=False,
    # Force CPU execution only when no CUDA device is available.
    use_cpu=not torch.cuda.is_available(),
)

# Assemble the Trainer from the objects built above.
trainer = Trainer(
    # What to train, and with which settings.
    model=finetune_model,
    args=train_args,
    # Data: raw records from the dataset are batched by the custom collator.
    train_dataset=train_dataset,
    data_collator=data_collator,
    processing_class=processor,
    # Custom loss callable defined in this notebook.
    compute_loss_func=loss,
)

In [5]:
# Load the checkpoint again (this time without the custom config) and show its
# module tree as the cell output.
# NOTE(review): this rebinds `model`; `finetune_model` was built from the
# earlier object -- confirm later cells expect this fresh copy.
model = GroundingDinoForObjectDetection.from_pretrained('./model')
model = model.to(device)
model

GroundingDinoForObjectDetection(
  (model): GroundingDinoModel(
    (backbone): GroundingDinoConvModel(
      (conv_encoder): GroundingDinoConvEncoder(
        (model): SwinBackbone(
          (embeddings): SwinEmbeddings(
            (patch_embeddings): SwinPatchEmbeddings(
              (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
            )
            (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (encoder): SwinEncoder(
            (layers): ModuleList(
              (0): SwinStage(
                (blocks): ModuleList(
                  (0): SwinLayer(
                    (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                    (attention): SwinAttention(
                      (self): SwinSelfAttention(
                        (query): Linear(in_features=96, out_features=96, bias=True)
                        (key): Linear(in_features=9

In [11]:
# Run the fine-tuning loop; the returned TrainOutput is displayed as the cell result.
trainer.train()

Step,Training Loss
20,2.1377
40,2.0972
60,2.3146
80,1.9668
100,1.7387
120,1.6736
140,1.6136
160,1.6371
180,1.6556
200,1.5748


TrainOutput(global_step=2900, training_loss=1.6051822116457182, metrics={'train_runtime': 2404.1778, 'train_samples_per_second': 1.206, 'train_steps_per_second': 1.206, 'total_flos': 7.62396412045337e+18, 'train_loss': 1.6051822116457182, 'epoch': 20.0})