In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import RandomCrop, ToPILImage
from torchvision.transforms.functional import crop
from tqdm import tqdm
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
)
from datasets import load_metric

In [None]:
class ContainerOCRDatasetText(Dataset):
    """Dataset of cropped text regions described by a COCO-style annotation file.

    One item is produced per annotation (bounding box) listed in
    ``<directory>/train/_annotations.coco.json`` (``is_train=True``) or
    ``<directory>/valid/_annotations.coco.json`` (``is_train=False``).
    The ground-truth text of an item is the second ``_``-separated field of
    the stored relative image path.

    Args:
        directory: Root folder containing the ``train`` and ``valid`` splits.
        processor: Object that is called on the cropped image and whose
            ``tokenizer`` attribute is called on the label text.
        is_train: Selects the ``train`` split when true, ``valid`` otherwise.

    Raises:
        ValueError: If ``directory`` or ``processor`` is ``None``.
    """

    def __init__(self, directory, processor, is_train=True):
        # Fail fast instead of returning a half-initialised object that would
        # only blow up later with an AttributeError in __len__/__getitem__.
        if directory is None:
            raise ValueError("Directory is none")
        if processor is None:
            raise ValueError("Processor is none")
        self.processor = processor
        self.directory = Path(directory)
        self.image_label = []
        split = "train" if is_train else "valid"
        self.decode(
            file_path=str(self.directory.joinpath(f"{split}/_annotations.coco.json")),
            is_train=is_train,
        )

    def decode(self, file_path: str, is_train=True):
        """Parse an annotation file and append one entry per bounding box.

        Each appended entry is a dict with the image path relative to the
        dataset root (``image_filename``) and the box as ``[x1, y1, x2, y2]``
        integers (``bbox``), computed from the file's ``[x, y, w, h]`` values.
        Entries are ordered by image, then by annotation order in the file.

        Args:
            file_path: Path of the JSON annotation file to read.
            is_train: Chooses the ``train/`` or ``valid/`` prefix of the
                stored image paths.
        """
        with open(file_path, encoding="utf-8") as file:
            json_data = json.load(file)

        # Group the boxes by image id in a single pass so that the lookup per
        # image is O(1) instead of rescanning every annotation for every image.
        boxes_by_image = {}
        for annotation in json_data["annotations"]:
            boxes_by_image.setdefault(annotation["image_id"], []).append(
                annotation["bbox"]
            )

        split = "train" if is_train else "valid"
        for image in json_data["images"]:
            image_filename = image["file_name"]
            for bounding_box in boxes_by_image.get(image["id"], ()):
                x1, y1 = int(bounding_box[0]), int(bounding_box[1])
                x2, y2 = x1 + int(bounding_box[2]), y1 + int(bounding_box[3])
                self.image_label.append(
                    {
                        "image_filename": f"{split}/{image_filename}",
                        "bbox": [x1, y1, x2, y2],
                    }
                )

    def __len__(self):
        """Return the number of annotated boxes collected by ``decode``."""
        return len(self.image_label)

    def __getitem__(self, index):
        """Load, crop and preprocess the sample at ``index``.

        Returns:
            A tuple ``(pixel_values, labels, text, original_image)`` where
            ``pixel_values`` is the first element of the processor output,
            ``labels`` is a tensor of token ids with padding positions
            replaced by ``-100``, ``text`` is the label string and
            ``original_image`` is the unprocessed crop.
        """
        entry = self.image_label[index]
        image = self.decode_image(str(self.directory.joinpath(entry["image_filename"])))
        text = entry["image_filename"].split("_")[1]
        x1, y1, x2, y2 = entry["bbox"]
        original_image = image[..., y1:y2, x1:x2]
        # torchvision's read_image yields uint8 pixels already in [0, 255].
        # Multiplying that tensor by 255 wraps around modulo 256 and inverts
        # the picture, so the crop is handed to the processor unchanged.
        image = self.processor(original_image, return_tensors="pt").pixel_values
        labels = self.processor.tokenizer(
            text, padding="max_length", max_length=10
        ).input_ids
        # Padding positions are replaced by -100 in the returned labels.
        pad_token_id = self.processor.tokenizer.pad_token_id
        labels = [label if label != pad_token_id else -100 for label in labels]
        return image[0], torch.tensor(labels), text, original_image

    @staticmethod
    def decode_image(image_path):
        """Read the image at ``image_path`` as an RGB tensor."""
        return read_image(image_path, ImageReadMode.RGB)

In [2]:
# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU so that the
# notebook does not fail on machines without MPS when the model is moved.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Processor and model are loaded from the same pretrained checkpoint.
TROCR_CHECKPOINT = "microsoft/trocr-base-printed"
processor = TrOCRProcessor.from_pretrained(TROCR_CHECKPOINT)
ocrModel = VisionEncoderDecoderModel.from_pretrained(TROCR_CHECKPOINT)
ocrModel = ocrModel.to(device)

# Take the special-token ids for the model config from the processor's tokenizer.
ocrModel.config.decoder_start_token_id = processor.tokenizer.cls_token_id
ocrModel.config.pad_token_id = processor.tokenizer.pad_token_id
ocrModel.config.eos_token_id = processor.tokenizer.sep_token_id

# Generation settings; max_length is the same value (10) the dataset uses
# as max_length when tokenizing labels.
ocrModel.config.max_length = 10
ocrModel.config.no_repeat_ngram_size = 3
ocrModel.config.length_penalty = 2.0
ocrModel.config.num_beams = 4

optimizer = torch.optim.AdamW(ocrModel.parameters(), lr=5e-5)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Sanity check: tokenize two short sentences and display the resulting encoding.
sample_sentences = ["hello there", "how are you"]
processor.tokenizer(sample_sentences)

{'input_ids': [[0, 42891, 89, 2], [0, 9178, 32, 47, 2]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1, 1]]}

In [5]:
# Sanity check: run the processor on a random integer batch drawn from [0, 255).
random_images = torch.randint(low=0, high=255, size=[3, 3, 256, 256])
processor(random_images)

{'pixel_values': [array([[[-0.78039217, -0.372549  , -0.00392157, ...,  0.4431373 ,
          0.15294123, -0.12156862],
        [-0.5529412 , -0.19999999,  0.03529418, ...,  0.27843142,
         -0.0745098 , -0.38039213],
        [-0.3098039 ,  0.03529418,  0.18431377, ...,  0.18431377,
         -0.16862744, -0.45098037],
        ...,
        [ 0.56078434, -0.09019607, -0.64705884, ...,  0.33333337,
          0.39607847,  0.5529412 ],
        [ 0.4666667 , -0.18431371, -0.73333335, ..., -0.23137254,
          0.09803927,  0.5764706 ],
        [ 0.34901965, -0.3098039 , -0.827451  , ..., -0.7490196 ,
         -0.17647058,  0.5764706 ]],

       [[ 0.7019608 ,  0.16078436, -0.2862745 , ..., -0.70980394,
         -0.34117645,  0.07450986],
        [-0.09803921, -0.38039213, -0.5764706 , ...,  0.02745104,
          0.26274514,  0.48235297],
        [-0.827451  , -0.79607844, -0.6862745 , ...,  0.52156866,
          0.5764706 ,  0.5921569 ],
        ...,
        [ 0.3803922 , -0.1607843 , -

In [None]:
# Show the processed image (left) next to the raw crop (right) for every
# training sample whose crop is at least as tall as it is wide.
preview_dataset = ContainerOCRDatasetText("container_dataset/", processor=processor)
for sample in preview_dataset:
    processed_image, label, image = sample[0], sample[1], sample[-1]

    height, width = image.size(1), image.size(2)
    if height >= width:
        fig, axes = plt.subplots(ncols=2)
        # imshow expects channels last, so permute the two leading axes to the back.
        axes[0].imshow(torch.permute(processed_image, [1, 2, 0]))
        axes[1].imshow(torch.permute(image, [1, 2, 0]))

In [None]:
# ocrModel.load_state_dict(torch.load('./data/model/33_model.pt', map_location=torch.device('cpu')))

In [None]:
# Both splits live under the same dataset root.
DATASET_ROOT = "./container_dataset/"

# Training split: shuffled, one sample per batch.
training_data = ContainerOCRDatasetText(DATASET_ROOT, processor, is_train=True)
train_dataloader = DataLoader(training_data, batch_size=1, shuffle=True)

# Validation split: fixed order, five samples per batch.
test_data = ContainerOCRDatasetText(DATASET_ROOT, processor, is_train=False)
test_dataloader = DataLoader(test_data, batch_size=5)

# Peek at the label tensor of the first sample in the first training batch.
for sample in train_dataloader:
    print(sample[1][0])
    break

# for sample in train_dataloader:
#     plt.figure()
#     print(sample[0].size())
#     plt.imshow(torch.permute(sample[0].squeeze(), [1, 2, 0]).cpu().numpy())

# plt.figure()
# plt.imshow(torch.permute(sample[2].squeeze(), [1, 2, 0]).cpu().numpy())
# break

In [None]:
cer_metric = load_metric("cer")


def compute_cer(pred_ids, label_ids):
    """Return the mean character error rate (CER) of a batch of predictions.

    Every prediction/label pair is decoded to text, printed, and scored;
    the running sum is printed after each pair.

    Args:
        pred_ids: Batch of generated token-id sequences.
        label_ids: Batch of label token ids in which ignored positions are
            marked with ``-100``. The argument itself is not modified.

    Returns:
        The sum of the per-sample scores divided by the number of
        predictions, or ``0.0`` for an empty batch.
    """
    if len(pred_ids) == 0:
        # Nothing to score; avoids dividing by zero below.
        return 0.0

    # Swap the -100 markers for the pad id once, on a copy, so the tokenizer
    # can decode the labels and the caller's tensor is left untouched.
    label_ids = torch.as_tensor(label_ids).masked_fill(
        torch.as_tensor(label_ids) == -100, processor.tokenizer.pad_token_id
    )

    sum_cer = 0
    for pred, label in zip(pred_ids, label_ids):
        pred_str = processor.decode(pred, skip_special_tokens=True)
        label_str = processor.decode(label, skip_special_tokens=True)
        print(f'"{label_str}"  "{pred_str}"')

        if label_str == "":
            # No reference characters to normalise by: count the spurious
            # predicted characters (0 when both strings are empty).
            sum_cer += len(pred_str)
        elif pred_str == "":
            # Every reference character is missing, which is a CER of 1.0.
            # Adding len(label_str) here would mix a raw count into a sum of ratios.
            sum_cer += 1.0
        else:
            sum_cer += cer_metric.compute(
                predictions=[pred_str], references=[label_str]
            )
        print(sum_cer)
    return sum_cer / len(pred_ids)

In [None]:
# Each epoch currently only evaluates: the training step below is disabled.
for epoch in range(200):
    # train_running_loss = 0.0
    # ocrModel.train()
    # for idx, data in enumerate(tqdm(train_dataloader)):

    #     inputs: torch.Tensor
    #     labels: torch.Tensor
    #     inputs, labels = data[0], data[1]

    #     inputs = inputs.to(device)
    #     labels = labels.to(device)

    #     output = ocrModel(inputs, labels=labels)
    #     output.loss.backward()
    #     train_running_loss += output.loss.item()
    #     optimizer.step()
    #     optimizer.zero_grad()

    ocrModel.eval()
    validation_cer = 0
    # Gradients are not needed for generation or scoring.
    with torch.no_grad():
        for batch in test_dataloader:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            generated_ids = ocrModel.generate(inputs)
            validation_cer += compute_cer(generated_ids, labels)
    # Average of the per-batch CER values over the validation batches.
    print(validation_cer / len(test_dataloader))

In [None]:
# Qualitative check: run the model on the first validation sample and print
# the generated text next to the decoded ground-truth label.
test_data = ContainerOCRDatasetText(
    directory="./container_dataset/",
    processor=processor,
    is_train=False,
)

for idx, sample in enumerate(test_data):
    with torch.no_grad():
        plt.figure()
        # imshow expects channels last.
        plt.imshow(torch.permute(sample[0], dims=[1, 2, 0]))

        # Put the pad id back in place of -100 so the label can be decoded.
        sample[1][sample[1] == -100] = processor.tokenizer.pad_token_id
        # Move the input to the same device as the model instead of a
        # hard-coded "mps", matching how the evaluation loop moves its tensors.
        generated_ids = ocrModel.generate(sample[0].to(device).unsqueeze(dim=0))
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        print(generated_text, processor.decode(sample[1], skip_special_tokens=True))

    # Only the first sample is inspected.
    if idx == 0:
        break