In [1]:
import json
from functools import partial

import bitsandbytes as bnb
import torch
from datasets import load_dataset
from peft import get_peft_model
from transformers import (
    AutoProcessor,
    DonutProcessor, VisionEncoderDecoderModel,
    Trainer,
    TrainingArguments,
)
import datetime
from transformers import VisionEncoderDecoderConfig


# --- Run configuration -------------------------------------------------------
# Single source of truth for the checkpoint used by config, processor and model.
model_id = "naver-clova-ix/donut-base"
run_name = "donut-base"

max_length = 164           # decoder length limit (generation and label truncation)
image_size = [2560, 1920]  # (height, width)

# Start from the checkpoint's own config, then override the encoder input
# resolution and the decoder's maximum length.
# NOTE(review): confirm whether this resolution differs from the checkpoint
# default, and whether the processor's image size has to be updated to match —
# only the model config is changed here.
config = VisionEncoderDecoderConfig.from_pretrained(model_id)
config.encoder.image_size = image_size
config.decoder.max_length = max_length

# Suffix the run name with a second-resolution timestamp so that repeated runs
# get distinct names.
timestamp = format(datetime.datetime.now(), "%Y%m%d%H%M%S")
run_name_with_timestamp = "_".join((run_name, timestamp))

processor = DonutProcessor.from_pretrained(model_id, config=config)

def collate_fn(processor: "DonutProcessor", examples, max_length=None):
    """Turn a list of dataset rows into one padded training batch.

    For every row, the object stored under ``"image"`` is converted to RGB and
    all remaining fields are serialised to a JSON string that serves as the
    target text. Images and texts are then handed to ``processor`` in a single
    call.

    Args:
        processor: Callable processor accepting a list of images plus
            ``text=...``; must expose ``tokenizer.pad_token_id``.
        examples: List of mapping-like rows. Each row needs an ``"image"``
            entry with a ``convert("RGB")`` method; the other entries must be
            JSON-serialisable.
        max_length: Truncation length forwarded to the processor. ``None``
            (the default) falls back to the module-level
            ``config.decoder.max_length``, which is what the original
            implementation always used.

    Returns:
        The processor's output. If it contains ``"labels"``, every position
        equal to the tokenizer's pad id is replaced by ``-100`` so that padded
        positions are excluded from the cross-entropy loss instead of being
        trained on.
    """
    if max_length is None:
        max_length = config.decoder.max_length

    images = [example['image'].convert('RGB') for example in examples]
    targets = [
        json.dumps({k: v for k, v in example.items() if k != "image"})
        for example in examples
    ]

    batch = processor(
        images,
        text=targets,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    # Without this mask the loss is also computed on the padding that
    # `padding=True` appends to the shorter targets in the batch.
    pad_token_id = processor.tokenizer.pad_token_id
    if pad_token_id is not None and "labels" in batch:
        labels = batch["labels"]
        batch["labels"] = labels.masked_fill(labels == pad_token_id, -100)

    return batch


# Bind the processor up front so the Trainer can call `collate(examples)` with
# just the list of rows.
collate = partial(collate_fn, processor)

# Load the dataset by its Hub id; the "train" and "test" splits are the ones
# handed to the Trainer further down.
dataset = load_dataset("arnaudstiegler/synthetic_us_passports_easy")

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


  from .autonotebook import tqdm as notebook_tqdm
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



In [13]:
# Instantiate the pretrained encoder-decoder using the overridden config
# (custom encoder image size and decoder max length).
model = VisionEncoderDecoderModel.from_pretrained(model_id, config=config)

# Turn on gradient checkpointing for training.
model.gradient_checkpointing_enable()

# Take the padding id and the decoder start id from the processor's tokenizer
# so that model and tokenizer agree on both.
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id

In [14]:
# Hyper-parameters and bookkeeping options for the Trainer.
args = TrainingArguments(
    # NOTE(review): absolute, machine-specific path — consider making this
    # configurable before sharing the notebook.
    output_dir="/Users/arnaudstiegler/Desktop/test/",
    num_train_epochs=4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=6,
    gradient_accumulation_steps=1,
    warmup_steps=1000,
    learning_rate=1e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=50,
    # Has no effect as written: no evaluation strategy is set, and the default
    # strategy never evaluates during training, so `eval_steps` is ignored.
    eval_steps=1,
    # The paged 8-bit AdamW comes from bitsandbytes and needs its CUDA build.
    # The import output of this notebook shows bitsandbytes installed without
    # GPU support, so fall back to the plain PyTorch AdamW when CUDA is absent.
    optim="paged_adamw_8bit" if torch.cuda.is_available() else "adamw_torch",
    save_strategy="steps",
    save_steps=500,
    push_to_hub=True,
    save_total_limit=1,
    bf16=True,
    run_name=run_name_with_timestamp,
    report_to=["wandb"],
    dataloader_pin_memory=False,
    # Keep every dataset column: the collate function needs the raw "image"
    # field and serialises all other fields into the target text.
    remove_unused_columns=False,
)


In [15]:
# Wire the model, hyper-parameters, batch collation and data splits together.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

In [16]:
# Start fine-tuning.
# NOTE(review): the output recorded for this cell shows the run aborting with
# an MPS out-of-memory error while the progress bar was still at 0/39000.
trainer.train()

  0%|          | 0/39000 [12:58<?, ?it/s]
  0%|          | 0/39000 [00:00<?, ?it/s]Unused or unrecognized kwargs: truncation, max_length, padding.
Unused or unrecognized kwargs: truncation, max_length, padding.
`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`...


RuntimeError: MPS backend out of memory (MPS allocated: 11.61 GB, other allocations: 2.02 GB, max allowed: 13.57 GB). Tried to allocate 150.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).