In [1]:
%pip install -q -U anthropic
%pip install -q -U lightning
%pip install -q -U bitsandbytes
%pip install -q -U transformers
%pip install -q -U peft
%pip install -q -U accelerate
%pip install -q -U wandb
%pip install -q -U datasets
%pip install -q -U prodigyopt
%pip install pdf2image
%pip install -q -U protobuf==3.20.0
!apt-get update --yes
!apt-get install -y poppler-utils

In [1]:
import torch
import os
import json
import random
from transformers import AutoProcessor, VisionEncoderDecoderModel
from torch.utils.data import Dataset, DataLoader
import torchvision as tv
from transformers import DonutProcessor
from torch.utils.data import DataLoader
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv
import lightning as L
from transformers import DonutProcessor, VisionEncoderDecoderModel
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import wandb
import os
from IPython.display import Image
from tqdm.auto import tqdm
from IPython.display import Image
import pickle
import torchvision.transforms as T
from datasets import Dataset
from datasets import IterableDataset
from tqdm.auto import tqdm
import unicodedata
import os
import gc
from torch.utils.data import Dataset
import torchvision as tv
import torchvision.transforms as T
from transformers import DonutProcessor
import random
from torch.utils.data import Subset
from peft import LoraConfig
from torch.optim.lr_scheduler import LinearLR
import lightning as L
from prodigyopt import Prodigy
from datetime import datetime
import torch.multiprocessing as mp
from torch.cuda.amp import autocast, GradScaler
from pdf2image import convert_from_path, convert_from_bytes
from transformers import VisionEncoderDecoderModel, DonutProcessor
from pytorch_lightning.loggers import WandbLogger
from pdf2image import convert_from_path
from torchvision.transforms import ToTensor

In [3]:
DATASET_PATH = "./documents"
PDF_PATH = os.path.join(DATASET_PATH, "german_pdf_files")
IMAGE_PATH = os.path.join(DATASET_PATH, "german_img_files")


class DocMeta:
    """In-memory index over the document collection's metadata.

    On construction it loads two JSON files:

    * ``<dataset_path>/extraction.json`` -- a dict keyed by document URL.
    * ``<image_path>/mapping.json`` -- a dict mapping each URL to an image
      file name located inside ``image_path``.
    """

    def __init__(self, dataset_path=DATASET_PATH, image_path=IMAGE_PATH):
        self.dataset_path = dataset_path
        self.image_path = image_path
        self.extraction = self.load_extraction(self.dataset_path)
        self.mapping = self.load_mapping(self.image_path)

    @staticmethod
    def _load_json(path, filename):
        """Parse ``path/filename`` as UTF-8 encoded JSON."""
        # JSON is UTF-8 by spec; without an explicit encoding, open() uses the
        # platform's locale encoding and can mangle or reject German umlauts.
        with open(os.path.join(path, filename), "r", encoding="utf-8") as f:
            return json.load(f)

    def get_random_urls(self, num=100):
        """Return ``num`` distinct URLs sampled without replacement.

        Raises:
            ValueError: (from ``random.sample``) if ``num`` exceeds the number
                of documents.
        """
        return random.sample(list(self.extraction.keys()), num)

    def load_extraction(self, path):
        """Load ``extraction.json`` from ``path``."""
        return self._load_json(path, "extraction.json")

    def load_mapping(self, path):
        """Load ``mapping.json`` (URL -> image file name) from ``path``."""
        return self._load_json(path, "mapping.json")

    def get_image_path(self, url):
        """Return the on-disk image path for ``url`` (raises KeyError if unmapped)."""
        return os.path.join(self.image_path, self.mapping[url])

    def get_image(self, url):
        """Return an IPython ``Image`` display object for the document's page image."""
        # NFC-normalize the path the same way the dataset does before reading
        # the file, so decomposed umlauts in mapping.json still resolve.
        img_path = unicodedata.normalize("NFC", self.get_image_path(url))
        return Image(filename=img_path)


# Build the metadata index once; the dataset below reads everything from it.
doc_meta = DocMeta(dataset_path=DATASET_PATH, image_path=IMAGE_PATH)

In [4]:
class DocMetaDataset(Dataset):
    """Torch dataset pairing document page images with Donut decoder targets.

    Each item is a dict with:

    * ``"pixel"``  -- the image tensor produced by the Donut processor,
    * ``"target"`` -- token ids of the wrapped author (``ocr=False``) or of the
      wrapped transcript (``ocr=True``), padded to a fixed length with padding
      positions set to -100 so the loss ignores them,
    * ``"url"``    -- the document URL.

    Documents are skipped when they have no usable author, when the wrapped
    author string exceeds ``MAX_AUTHOR_CHARS`` characters, or when the image is
    larger than ``MAX_IMAGE_WIDTH`` x ``MAX_IMAGE_HEIGHT``.

    Args:
        doc_meta: object exposing ``extraction`` (dict keyed by URL) and
            ``get_image_path(url)``.
        output_dir: directory created on construction (not otherwise used here).
        ocr: if True, ``__getitem__`` returns transcript targets instead of
            author targets. May be toggled after construction.
    """

    MAX_IMAGE_WIDTH = 2160
    MAX_IMAGE_HEIGHT = 3840
    MAX_AUTHOR_CHARS = 80  # applies to the wrapped string, i.e. incl. "<s>"/"</s>"
    AUTHOR_MAX_TOKENS = 80
    TRANSCRIPT_MAX_TOKENS = 300

    def __init__(self, doc_meta, output_dir, ocr=False):
        self.doc_meta = doc_meta
        self.task_start_token = "<s>"
        self.eos_token = "</s>"
        self.ocr = ocr
        self.processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
        self.processor.tokenizer.add_special_tokens(
            {"additional_special_tokens": [self.task_start_token, self.eos_token]}
        )
        # Same numbers as the image-size filter in _prepare_data; assumed to be
        # interpreted as (width, height) by the processor -- TODO confirm.
        self.processor.feature_extractor.size = [
            self.MAX_IMAGE_WIDTH,
            self.MAX_IMAGE_HEIGHT,
        ]
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.data = self._prepare_data()

    def _wrap(self, text):
        """Surround ``text`` with the task-start and end-of-sequence tokens."""
        return self.task_start_token + text + self.eos_token

    @staticmethod
    def _format_author(main_author):
        """Extract a plain author string from a ``main_author`` record.

        Fields are tried in order: ``name``; then ``first_name`` + ``last_name``;
        then ``department``.

        Returns:
            The author string, or None when the record is None, has none of
            these keys, or yields an empty / whitespace-only name.
        """
        if main_author is None:
            return None
        if "name" in main_author:
            author = main_author["name"]
        elif "first_name" in main_author and "last_name" in main_author:
            # Join only the parts that are present so a missing first name does
            # not leave a leading space in the training target.
            parts = (main_author["first_name"], main_author["last_name"])
            author = " ".join(part for part in parts if part)
        elif "department" in main_author:
            author = main_author["department"]
        else:
            return None
        # A null value such as {"name": None} would otherwise crash the
        # string concatenation when the author is wrapped.
        if author is None or not author.replace(" ", ""):
            return None
        return author

    def _prepare_data(self):
        """Build the list of ``(image_path, author, url, transcript)`` samples."""
        data = []
        for url, record in tqdm(
            list(self.doc_meta.extraction.items()), desc="Preparing dataset"
        ):
            author = self._format_author(record["metadata"]["main_author"])
            if author is None:
                continue
            author = self._wrap(author)
            if len(author) > self.MAX_AUTHOR_CHARS:
                continue
            transcript = self._wrap(record["transcript"])

            # Decoding the image is by far the most expensive check, so it only
            # runs for documents that passed the cheap text filters above.
            image_path = self.doc_meta.get_image_path(url)
            image_data = tv.io.read_image(unicodedata.normalize("NFC", image_path))
            width, height = tv.transforms.functional.get_image_size(image_data)
            if width > self.MAX_IMAGE_WIDTH or height > self.MAX_IMAGE_HEIGHT:
                continue

            data.append((image_path, author, url, transcript))
        return data

    def _encode_target(self, text, max_length):
        """Tokenize ``text`` to exactly ``max_length`` ids with padding set to -100."""
        input_ids = (
            self.processor.tokenizer(
                text,
                add_special_tokens=False,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )["input_ids"]
            .squeeze(0)
            .long()
        )
        target = input_ids.clone()
        # -100 is the index the cross-entropy loss ignores.
        target[target == self.processor.tokenizer.pad_token_id] = -100
        return target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path, author, url, transcript = self.data[idx]

        # read_image keeps the file's native channel count by default; force
        # RGB so grayscale or RGBA pages still yield 3-channel input.
        image_data = tv.io.read_image(
            unicodedata.normalize("NFC", image_path), mode=tv.io.ImageReadMode.RGB
        )
        # read_image returns (C, H, W); the processor gets a channels-last array.
        image_data = image_data.permute(1, 2, 0)

        pixel_values = self.processor(image_data, return_tensors="pt").pixel_values
        pixel_values = pixel_values.squeeze(0)  # drop the batch dimension of 1

        if self.ocr:
            target = self._encode_target(transcript, self.TRANSCRIPT_MAX_TOKENS)
        else:
            target = self._encode_target(author, self.AUTHOR_MAX_TOKENS)
        return {"pixel": pixel_values, "target": target, "url": url}

    def get_split(self, train_ratio=0.8, val_ratio=0.1, seed=None):
        """Randomly partition the dataset into train / val / test ``Subset``s.

        The test split receives the remainder, so every sample is used exactly once.

        Args:
            train_ratio: fraction of samples for training (floored to an int count).
            val_ratio: fraction of samples for validation (floored to an int count).
            seed: if given, a private ``random.Random(seed)`` shuffles the indices,
                making the split reproducible without touching the global RNG;
                otherwise the global ``random`` module is used.

        Raises:
            ValueError: if a ratio is negative or the two ratios sum to more than 1.
        """
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
            raise ValueError(
                f"invalid split ratios: train={train_ratio}, val={val_ratio}"
            )

        total_size = len(self.data)
        train_size = int(train_ratio * total_size)
        val_size = int(val_ratio * total_size)

        indices = list(range(total_size))
        rng = random.Random(seed) if seed is not None else random
        rng.shuffle(indices)

        val_end = train_size + val_size
        return (
            Subset(self, indices[:train_size]),
            Subset(self, indices[train_size:val_end]),
            Subset(self, indices[val_end:]),
        )

In [5]:
# Build the sample pool; author/size filtering happens inside DocMetaDataset._prepare_data.
dataset = DocMetaDataset(doc_meta=doc_meta, output_dir="./")

In [6]:
# Fixed seed so train/val/test membership is identical across kernel restarts;
# with seed=None the split is reshuffled by the global RNG on every run.
SPLIT_SEED = 42
train, val, test = dataset.get_split(train_ratio=0.8, val_ratio=0.18, seed=SPLIT_SEED)
len(train), len(val), len(test)

In [7]:
# Train on author strings (False) rather than OCR transcripts (True).
# The train/val/test Subsets wrap this same object, so the flag applies to all of them.
OCR_TARGETS = False
dataset.ocr = OCR_TARGETS

In [8]:
batch_size = 1
num_workers = 0


def _make_loader(subset, shuffle):
    """Wrap ``subset`` in a DataLoader using the shared batch/worker/pinning settings."""
    return DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )


# Only the training split is shuffled; val/test are iterated in a fixed order.
train_dataloader = _make_loader(train, shuffle=True)
val_dataloader = _make_loader(val, shuffle=False)
test_dataloader = _make_loader(test, shuffle=False)

In [2]:
# Start from the pretrained Donut checkpoint and align it with the tokenizer
# that the dataset extended with its special tokens.
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

donut_tokenizer = dataset.processor.tokenizer
model.decoder.resize_token_embeddings(len(donut_tokenizer))
model.config.pad_token_id = donut_tokenizer.pad_token_id
# Decoding starts from the same "<s>" token that prefixes every training target.
start_token_id = donut_tokenizer.convert_tokens_to_ids(["<s>"])[0]
model.config.decoder_start_token_id = start_token_id

# LoRA experiment, currently disabled (the full model is fine-tuned instead):
# peft_config = LoraConfig(
#    lora_alpha=128, lora_dropout=0.05, r=256, bias="none", target_modules="all-linear"
# )
# model = get_peft_model(model, peft_config)
# model.print_trainable_parameters()

model = model.to("cuda")

In [10]:
# Trade float32 matmul precision for speed on supported GPUs (see the
# torch.set_float32_matmul_precision docs for what "medium" permits).
torch.set_float32_matmul_precision("medium")

In [11]:
# Authenticate with Weights & Biases; needed before the wandb.init() call in the training cell.
wandb.login()

In [12]:
# compiled_model = torch.compile(model) # not working

In [13]:
# Loss scaler for fp16 autocast training (see comments in train_one_epoch).
scaler = GradScaler()

# Prodigy adapts its own step size; lr=1.0 leaves that estimate unscaled.
optimizer = Prodigy(model.parameters(), lr=1.0)

# T_max is measured in optimizer steps (batches), so the scheduler is stepped
# once per batch inside train_one_epoch -- stepping it once per epoch would
# leave the learning rate essentially constant.
SCHEDULE_EPOCHS = 10
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=SCHEDULE_EPOCHS * len(train_dataloader)
)

LOG_EVERY = 10  # batches per logged training-loss average
EVAL_EVERY = 200  # periodic held-out evaluation interval (batches)
EVAL_AT = 190  # 0-based batch index within each EVAL_EVERY window that triggers it

run = wandb.init(
    # Set the project where this run will be logged
    project="donut_torch",
)


def _mean_loss(dataloader):
    """Mean model loss over ``dataloader`` in eval mode (NaN if it is empty).

    The model's previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()  # disable dropout for a deterministic loss estimate
    total, num_batches = 0.0, 0
    try:
        with torch.no_grad():
            for batch in dataloader:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    outputs = model(
                        pixel_values=batch["pixel"].to("cuda"),
                        labels=batch["target"].to("cuda"),
                    )
                total += outputs.loss.item()
                num_batches += 1
    finally:
        model.train(was_training)
    return total / num_batches if num_batches else float("nan")


def train_one_epoch(epoch_index):
    """Run one mixed-precision pass over ``train_dataloader``.

    Every LOG_EVERY batches the mean training loss is printed and logged to
    wandb as ``train_loss``. At 0-based batch indices EVAL_AT, EVAL_AT + EVAL_EVERY, ...
    the mean loss over ``test_dataloader`` is printed and logged as ``v_loss``.

    Args:
        epoch_index: 0-based epoch number (unused; kept for call compatibility).

    Returns:
        The mean loss of the last complete LOG_EVERY-batch window, or 0.0 if
        the epoch had fewer than LOG_EVERY batches.
    """
    model.train()
    running_loss = 0.0
    last_loss = 0.0

    for batch_idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(
                pixel_values=batch["pixel"].to("cuda"),
                labels=batch["target"].to("cuda"),
            )
            loss = outputs.loss

        # GradScaler: scale the loss up before backward so small fp16 gradients
        # don't underflow; step() unscales the gradients and skips the update if
        # any are inf/NaN; update() then adjusts the scale for the next batch.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # per batch, matching the batch-based T_max

        running_loss += loss.item()
        if batch_idx % LOG_EVERY == LOG_EVERY - 1:
            last_loss = running_loss / LOG_EVERY  # mean loss per batch
            print("  batch {} loss: {}".format(batch_idx + 1, last_loss))
            wandb.log({"train_loss": last_loss})
            running_loss = 0.0

        # Periodic held-out loss. NOTE(review): this monitors the *test* split
        # during training; switch to val_dataloader if test is meant to stay unseen.
        if batch_idx % EVAL_EVERY == EVAL_AT:
            avg_vloss = _mean_loss(test_dataloader)
            print("  vloss: {}".format(avg_vloss))
            wandb.log({"v_loss": avg_vloss})

    return last_loss

In [41]:
# TODO: epochs too long

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

EPOCHS = 5

# Any finite validation loss improves on this, so the first epoch is always saved.
best_vloss = float("inf")


for epoch_number in range(EPOCHS):
    print("EPOCH {}:".format(epoch_number + 1))

    model.train(True)
    avg_loss = train_one_epoch(epoch_number)

    # Validation pass: eval mode disables dropout, no_grad skips autograd bookkeeping.
    model.eval()
    running_vloss = 0.0
    num_vbatches = 0
    with torch.no_grad():
        for vbatch in val_dataloader:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                voutputs = model(
                    pixel_values=vbatch["pixel"].to("cuda"),
                    labels=vbatch["target"].to("cuda"),
                )
            running_vloss += voutputs.loss.item()
            num_vbatches += 1

    # Count batches explicitly: an empty val split would otherwise divide by
    # an undefined (or stale) loop index.
    avg_vloss = running_vloss / num_vbatches if num_vbatches else float("nan")
    print("LOSS train {} valid {}".format(avg_loss, avg_vloss))
    wandb.log({"val_loss": avg_vloss})

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        print("Saving model ...")
        best_vloss = avg_vloss
        model_path = "model_{}_{}".format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
        wandb.save(model_path)

# Storage and quantization (currently only the upload to the Hugging Face Hub is implemented)


In [15]:
import huggingface_hub

# Authenticate with the Hugging Face Hub so push_to_hub() in the next cell can write to the account.
huggingface_hub.login()

In [16]:
# Upload the fine-tuned model to the Hub repository "doc_meta" of the logged-in account.
# Only the model is pushed; the inference section rebuilds the processor from
# naver-clova-ix/donut-base and re-adds the special tokens itself.
# Does not work when the model is wrapped with peft LoRA.
model.push_to_hub("doc_meta")

# Loading the model from the Hub and inference


In [21]:
# Rebuild the processor the same way the training dataset did: base Donut
# processor, same special tokens, same input size. Only the model was pushed.
task_start_token, eos_token = "<s>", "</s>"

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
processor.tokenizer.add_special_tokens(
    {"additional_special_tokens": [task_start_token, eos_token]}
)
processor.feature_extractor.size = [2160, 3840]

model = VisionEncoderDecoderModel.from_pretrained("sodowo/doc_meta")

# Re-apply the tokenizer alignment used during training.
model.decoder.resize_token_embeddings(len(processor.tokenizer))
model.config.pad_token_id = processor.tokenizer.pad_token_id
start_token_id = processor.tokenizer.convert_tokens_to_ids([task_start_token])[0]
model.config.decoder_start_token_id = start_token_id

In [3]:
# nn.Module.eval() switches to inference mode (e.g. disables dropout) and returns
# the module itself, so eval_model and model are the same object.
eval_model = model.eval()

In [4]:
# Render only the first page of the PDF and save it next to the PDF as a .jpg.
pdf_path = "./uni_hd.pdf"
image_data = convert_from_path(pdf_path, fmt="jpeg", first_page=1, last_page=1)[0]
# splitext swaps only the final extension. str.replace(".pdf", ".jpg") would
# also rewrite an earlier ".pdf" in the path, and for "X.PDF" it would return
# the PDF path unchanged, so save() would overwrite the source PDF.
image_path = os.path.splitext(pdf_path)[0] + ".jpg"
image_data.save(image_path)

In [5]:
# Decode the saved page as (C, H, W), then move channels last for the processor.
image_data = tv.io.read_image(unicodedata.normalize("NFC", image_path)).permute(1, 2, 0)

In [6]:
# Resize/normalize with the Donut processor and squeeze away the size-1 batch dimension.
pixel_values = processor(image_data, return_tensors="pt").pixel_values.squeeze()

In [7]:
eval_model = eval_model.to("cuda")
with torch.no_grad():
    # Without an explicit length budget, generate() falls back to the generation
    # config's max_length (transformers' default is 20), which can cut off authors
    # trained with up to 80 target tokens. Special-token ids are passed
    # explicitly because they were set on model.config after loading.
    generated_ids = eval_model.generate(
        pixel_values.unsqueeze(0).to("cuda"),  # add a batch dimension of 1
        max_new_tokens=80,
        decoder_start_token_id=eval_model.config.decoder_start_token_id,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
    )

In [8]:
# Inspect the raw generated token ids (a batch containing one sequence).
generated_ids

In [9]:
# Decode the single generated sequence to text, dropping special tokens such as <s>, </s> and padding.
predicted_label = processor.decode(generated_ids[0], skip_special_tokens=True)

In [10]:
# Predicted author string for the PDF's first page.
predicted_label