In [1]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from datasets import load_dataset
import matplotlib.pyplot as plt

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed").to(device)

Downloading (…)rocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


Downloading tokenizer_config.json:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

Downloading vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/4.13k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

: 

In [None]:
# Special tokens used for creating the decoder_input_ids from the labels during training.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# Beam-search / decoding parameters. Controlling generate() through model.config is
# deprecated in transformers; model.generation_config is the supported location.
# The token ids are copied over explicitly rather than relying on any fallback from
# generation_config to model.config.
model.generation_config.decoder_start_token_id = model.config.decoder_start_token_id
model.generation_config.pad_token_id = model.config.pad_token_id
model.generation_config.eos_token_id = model.config.eos_token_id
model.generation_config.max_length = 64
model.generation_config.early_stopping = True
model.generation_config.no_repeat_ngram_size = 3
model.generation_config.length_penalty = 2.0
model.generation_config.num_beams = 4

In [None]:
class ImageCaptioningDataset(Dataset):
    """Wraps an indexable dataset of image/text records as TrOCR training samples.

    Each record's ``"text"`` must contain two ``<field>...</field>`` spans; their
    contents are joined with a newline to form the target string. Each sample is
    returned as a dict with ``"pixel_values"`` (processor output, batch dim
    squeezed) and ``"labels"`` (token ids, padding replaced by -100).
    """

    _OPEN_TAG = "<field>"
    _CLOSE_TAG = "</field>"

    def __init__(self, dataset, processor, max_target_length=20):
        """
        Args:
            dataset: indexable collection whose items have an ``"image"`` entry
                (an object with ``.convert``, e.g. a PIL image) and a ``"text"`` str.
            processor: processor that preprocesses images when called and exposes
                a ``.tokenizer`` (a TrOCRProcessor in this notebook).
            max_target_length: length labels are padded to (``padding="max_length"``).
        """
        self.dataset = dataset
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.dataset)

    @classmethod
    def _extract_field(cls, text, search_from=0):
        """Return ``(content, end)`` for the first ``<field>...</field>`` span at or
        after ``search_from``; ``end`` is the index just past the closing tag.

        Raises:
            ValueError: if no complete span is found (instead of silently slicing
                with ``find() == -1`` offsets).
        """
        open_pos = text.find(cls._OPEN_TAG, search_from)
        if open_pos == -1:
            raise ValueError(f"missing {cls._OPEN_TAG} at or after index {search_from} in {text!r}")
        start = open_pos + len(cls._OPEN_TAG)
        stop = text.find(cls._CLOSE_TAG, start)
        if stop == -1:
            raise ValueError(f"missing {cls._CLOSE_TAG} after index {start} in {text!r}")
        return text[start:stop], stop + len(cls._CLOSE_TAG)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # Parse into a local instead of writing back into `item`: if the dataset
        # hands out shared dicts (e.g. a plain list), mutation would break re-reads.
        raw_text = item["text"]
        field_x, end = self._extract_field(raw_text)
        field_y, _ = self._extract_field(raw_text, end)
        target_text = field_x + "\n" + field_y

        pixel_values = self.processor(item["image"].convert("RGB"), return_tensors="pt").pixel_values

        tokenizer = self.processor.tokenizer
        labels = tokenizer(target_text, padding="max_length", max_length=self.max_target_length).input_ids
        # -100 is PyTorch CrossEntropyLoss's default ignore_index; assumes the model's
        # loss uses that default so padded positions do not contribute.
        labels = [tok if tok != tokenizer.pad_token_id else -100 for tok in labels]

        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

In [None]:
# Both splits come from the same dataset and share the processor loaded above.
dataset_id = "martinsinnona/visdecode"

train_dataset = ImageCaptioningDataset(load_dataset(dataset_id, split="train"), processor)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
)

test_dataset = ImageCaptioningDataset(load_dataset(dataset_id, split="test"), processor)

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

epochs = 10
losses = []  # mean training loss per batch, one entry per epoch

model.to(device)
model.train()

for epoch in range(epochs):
    epoch_loss = 0.0  # sum of batch losses for this epoch

    for idx, batch in enumerate(train_dataloader):
        if idx % 100 == 0:
            print(idx, " /", len(train_dataloader))

        batch = {k: v.to(device) for k, v in batch.items()}

        # The batch carries `labels`, so the model output includes `.loss`.
        output = model(**batch)
        loss = output.loss
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # .item() returns a detached Python float copied to the host.
        epoch_loss += loss.item()

    epoch_loss /= len(train_dataloader)

    print("Epoch: ", epoch, " | batch mean loss:", epoch_loss)
    losses.append(epoch_loss)

fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="Epoch", ylabel="Mean loss per batch");

In [None]:
# Inference: model.train() was left on by the training cell, so switch to eval mode
# (disables dropout) and skip autograd bookkeeping while generating.
model.eval()

with torch.no_grad():
    for idx in range(len(test_dataset)):
        sample = test_dataset[idx]
        pixel_values = sample["pixel_values"].unsqueeze(0).to(device)

        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Labels mark padding with -100; map it back to the pad id so they can be decoded.
        label_ids = sample["labels"].clone()
        label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
        reference_text = processor.tokenizer.decode(label_ids, skip_special_tokens=True)

        print(f"pred: {generated_text!r} | ref: {reference_text!r}")