In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("..")
from ocr_ensemble.data import load_dataset
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from functools import partial
from collections import defaultdict
from copy import deepcopy
import torch

from paddleocr import draw_ocr
from PIL import Image

In [12]:
from ocr_ensemble.classifiers import ClipEmbedding, ClipMulticlass, ClipPresence
from ocr_ensemble.proposers import PaddleOCRProposalGenerator
from ocr_ensemble.proposers import rotatedCrop
from ocr_ensemble.experts import HandwrittenExpert, Stage1Expert, PaddleOCRExpert, PrintedExpert, SceneExpert
from ocr_ensemble.data import identity

In [21]:
# Run configuration: compute device, expert size tag, and input/output paths.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Size tag interpolated into the expert names and the result file name.
expert_size = 'large'

# Input parquet file. The 1K subset is the active input. The 10K file used to
# be assigned first and was immediately overwritten (a dead store), so it is
# kept here only as a commented-out alternative.
# parquet_fname = '../data/laion2b-en-10K.parquet'
parquet_fname = '../data/laion2b-en-1K.parquet'

# Output parquet file that receives the OCR result columns.
parquet_result_fname = f'../data/laion2b-en-1K-experts-{expert_size}-4.parquet'

# Dataset location handed to `load_dataset` together with `parquet_fname`.
dataset_path = '../data/laion2b-en-1K-large'


In [14]:
# Base dataset; the loops below iterate it as (img, caption, key) tuples.
# NOTE(review): whether this call downloads images or only reads existing
# shards is decided inside `load_dataset` and is not visible from here.
dataset = load_dataset(dataset_path, parquet_fname, image_size=512, number_sample_per_shard=100)

# compute subset of images that contain text

In [15]:
# One CLIP embedding object, passed both to the presence classifier here and
# to the multi-class crop classifier created later in the notebook.
clip_emb = ClipEmbedding()
# Per-image classifier; its predictions are compared against 1 further down
# to select the images that are treated as containing text.
presence = ClipPresence(clip_emb=clip_emb)

In [16]:
# Work on a copy so the transform below is not attached to `dataset` itself.
dataset_clip = deepcopy(dataset)
# Presumably applies the presence transform to the first tuple element (the
# image) and passes the other two through `identity` unchanged.
# NOTE(review): the return value is discarded, so this relies on `map_tuple`
# modifying `dataset_clip` in place — confirm for the installed webdataset.
dataset_clip.map_tuple(presence.get_transform(), identity, identity)

<webdataset.compat.WebDataset at 0x1eb4ecf0df0>

In [17]:
# Batches of 200 samples, loaded with 4 worker processes.
loader = DataLoader(dataset_clip, batch_size=200, num_workers=4)

In [18]:
# Run the presence classifier over every batch and remember the prediction
# that belongs to each sample key.
features = []       # never filled in this cell
label_dict = {}     # sample key -> presence prediction
batch_preds = []    # one prediction array per batch
for imgs, captions, keys in tqdm(loader):
    preds = presence.predict(imgs.to(device))
    batch_preds.append(preds)
    label_dict.update(zip(keys, preds))

# Join the per-batch predictions into a single array.
labels = np.concatenate(batch_preds, axis=0)

6it [00:16,  2.72s/it]


# compute & label bounding boxes

In [19]:
from torch.utils.data import default_collate

def collate(batch):
    """Collate three-element samples into one batch.

    The first element of every sample is merged with ``default_collate``;
    the second and third elements are returned as plain Python lists in
    sample order.
    """
    # Materialise once so that a one-shot iterable is consumed a single time.
    samples = [(img, bbox, label) for img, bbox, label in batch]
    imgs = [sample[0] for sample in samples]
    bboxes = [sample[1] for sample in samples]
    labels = [sample[2] for sample in samples]
    return default_collate(imgs), bboxes, labels

def get_crops(src, label_dict, proposer):
    """Yield one (crop, bbox, key) tuple per region proposed for an image.

    Samples from ``src`` whose ``label_dict`` entry is not 1 are skipped.
    For the remaining samples ``proposer(img)`` supplies the crops and
    their bounding boxes, which are emitted pairwise together with the
    sample key.
    """
    for img, caption, key in tqdm(src):
        if label_dict[key] != 1:
            continue
        crops, bboxes = proposer(img)
        yield from ((crop, bbox, key) for crop, bbox in zip(crops, bboxes))

def get_imgs_containing_text(src, label_dict):
    """Pass through only the samples whose ``label_dict`` entry equals 1."""
    for sample in tqdm(src):
        img, caption, key = sample
        if label_dict[key] == 1:
            yield img, caption, key

# Region proposer called on each image inside `get_crops`; pinned to the CPU here.
proposer = PaddleOCRProposalGenerator(device='cpu')
# Separate copy of the base dataset for the crop pipeline.
dataset_crops = deepcopy(dataset)
# Feed the sample stream through `get_crops`, which yields (crop, bbox, key) tuples.
# NOTE(review): the return value is discarded, so this relies on `compose`
# modifying `dataset_crops` in place — confirm for the installed webdataset.
dataset_crops.compose(partial(get_crops, label_dict=label_dict, proposer=proposer))

<webdataset.compat.WebDataset at 0x1eb79b08520>

In [22]:
# Expert names, built once and shared by both dictionaries below so that
# their keys (and key order) cannot drift apart.
name_handwritten = f"trocr-{expert_size}-handwritten"
name_printed = f"trocr-{expert_size}-printed"
name_scene = f"trocr-{expert_size}-str"
name_stage1 = f"trocr-{expert_size}-stage1"

# Expert name -> text description handed to `ClipMulticlass`.
expert_text_dict = {
    name_handwritten: "handwritten text, handwriting, black on white",
    name_printed: "text in a document, website, or presentation",
    name_scene: "text in a scene",
    name_stage1: "rendered text",
}

# Expert name -> expert instance, created in the same order as above.
expert_dict = {
    name_handwritten: HandwrittenExpert(expert_size),
    name_printed: PrintedExpert(expert_size),
    name_scene: SceneExpert(expert_size),
    name_stage1: Stage1Expert(expert_size),
}

# The classifier receives the descriptions in dictionary order.
class_texts = list(expert_text_dict.values())
clf = ClipMulticlass(class_texts, clip_emb=clip_emb)

dataset_crops.map_tuple(clf.get_transform(), identity, identity)
# here num_worker breaks things in non-trivial ways
loader_crops = DataLoader(dataset_crops, batch_size=200, collate_fn=collate)

microsoft/trocr-large-handwritten


Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-large-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


microsoft/trocr-large-printed


Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-large-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


microsoft/trocr-large-str


Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


microsoft/trocr-large-stage1


Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-large-stage1 and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


only supported targets ['handwritten', 'printed', 'scene', 'stage1']


In [23]:
# Per-image accumulators, keyed by sample key.
bbox_dict = defaultdict(list)           # key -> bounding boxes of its crops
bbox_label_dict = defaultdict(list)     # key -> classifier label of each crop
bbox_scores_dict = defaultdict(list)    # stays empty: score collection is disabled

for crops, bboxes, keys in tqdm(loader_crops):
    labels = clf.predict(crops.to(device))

    # labels, bboxes and keys are walked in lockstep, one entry per crop
    for label, bbox, key in zip(labels, bboxes, keys):
        bbox_dict[key].append(bbox)
        bbox_label_dict[key].append(label)

0it [00:00, ?it/s]
0it [00:00, ?it/s]
3it [00:00,  5.43it/s]
4it [00:01,  2.66it/s]
6it [00:01,  3.55it/s]
7it [00:02,  3.03it/s]
9it [00:02,  3.94it/s]
11it [00:02,  4.98it/s]
12it [00:03,  4.07it/s]
13it [00:03,  3.51it/s]
14it [00:04,  2.94it/s]
17it [00:04,  4.38it/s]
18it [00:04,  4.32it/s]
19it [00:04,  4.11it/s]
22it [00:05,  5.87it/s]
26it [00:05,  8.03it/s]
27it [00:06,  4.91it/s]
28it [00:06,  4.70it/s]
31it [00:06,  5.61it/s]
34it [00:07,  6.79it/s]
39it [00:07,  8.50it/s]
40it [00:07,  7.74it/s]
43it [00:08,  8.32it/s]
45it [00:09,  3.74it/s]
46it [00:09,  3.34it/s]
48it [00:10,  3.91it/s]
50it [00:10,  4.04it/s]
51it [00:11,  3.65it/s]
52it [00:11,  3.74it/s]
54it [00:11,  4.72it/s]
56it [00:11,  5.11it/s]
58it [00:12,  4.86it/s]
59it [00:12,  4.40it/s]
63it [00:13,  6.74it/s]
65it [00:13,  6.15it/s]
1it [00:14, 14.53s/it]]
68it [00:14,  3.70it/s]
69it [00:14,  3.68it/s]
70it [00:16,  2.05it/s]
71it [00:16,  2.31it/s]
75it [00:16,  4.21it/s]
77it [00:17,  4.47it/s]
78it [0

577it [02:08,  5.60it/s]
9it [02:09, 14.34s/it]s]
579it [02:09,  3.16it/s]
580it [02:09,  3.36it/s]
581it [02:09,  3.55it/s]
587it [02:10,  7.38it/s]
591it [02:10,  9.30it/s]
592it [02:10,  7.58it/s]
594it [02:11,  7.70it/s]
596it [02:11,  7.66it/s]
597it [02:11,  6.63it/s]
599it [02:11,  7.09it/s]
600it [02:12,  6.46it/s]
601it [02:12,  5.49it/s]
605it [02:12,  8.09it/s]
607it [02:12,  8.05it/s]
608it [02:13,  6.67it/s]
609it [02:13,  5.68it/s]
612it [02:13,  6.94it/s]
613it [02:14,  6.32it/s]
614it [02:14,  5.64it/s]
619it [02:14,  8.78it/s]
622it [02:14,  9.38it/s]
623it [02:15,  7.63it/s]
627it [02:15,  8.32it/s]
628it [02:15,  6.96it/s]
629it [02:16,  5.98it/s]
630it [02:16,  5.07it/s]
631it [02:16,  4.49it/s]
634it [02:17,  6.32it/s]
637it [02:17,  7.68it/s]
639it [02:17,  6.46it/s]
640it [02:18,  5.31it/s]
641it [02:18,  4.94it/s]
643it [02:18,  5.23it/s]
644it [02:19,  3.76it/s]
645it [02:19,  3.26it/s]
646it [02:20,  3.29it/s]
648it [02:20,  4.29it/s]
651it [02:20,  5.83it/s]


# perform OCR with the experts

In [24]:
def get_efficient_and_filtered_crops(src, expert_idx, bbox_dict, bbox_label_dict):
    """Yield (crop, bbox_idx, key) for every box assigned to one expert.

    For each sample of ``src`` the boxes and labels stored under its key are
    walked together. Only boxes whose label equals ``expert_idx`` are cropped
    (via ``rotatedCrop``) and emitted. ``bbox_idx`` is the position of the box
    in the sample's box list.
    """
    for img, caption, key in tqdm(src):
        labelled_boxes = zip(bbox_dict[key], bbox_label_dict[key])
        for bbox_idx, (bbox, label) in enumerate(labelled_boxes):
            if label != expert_idx:
                continue
            yield rotatedCrop(img, bbox), bbox_idx, key

def collate_crops(batch, img_collate):
    """Collate (crop, bbox, key) samples using a caller-supplied image collate.

    Returns ``(img_collate(crops), bboxes, keys)`` where ``crops``, ``bboxes``
    and ``keys`` are lists in sample order.
    """
    samples = [(crop, bbox, key) for crop, bbox, key in batch]
    if samples:
        # transpose the list of triples into three parallel lists
        crops, bboxes, keys = (list(column) for column in zip(*samples))
    else:
        crops, bboxes, keys = [], [], []
    return img_collate(crops), bboxes, keys

In [None]:
# OCR results per sample key: a list of (bbox_idx, text) pairs, filled by one
# expert after the other.
ocr_dict = defaultdict(list)
for idx, key in enumerate(expert_dict.keys()):
    print(f'Expert {key} ...')
    expert = expert_dict[key]
    # Fresh copy of the base dataset, reduced to the crops whose predicted
    # label equals this expert's position in `expert_dict`.
    dataset_filtered = deepcopy(dataset)
    dataset_filtered.compose(partial(get_efficient_and_filtered_crops,
                                     expert_idx = idx, 
                                     bbox_dict = bbox_dict,
                                     bbox_label_dict = bbox_label_dict))
    dataset_filtered.map_tuple(expert.get_transform(), identity, identity)
    # Compare by value: `is` tests object identity, which is not guaranteed
    # for equal strings and triggers a SyntaxWarning on Python >= 3.8.
    if key == "paddleocr":
        collate_fn = partial(collate_crops, img_collate=expert.get_collate())
    else:
        collate_fn = collate
    loader_expert = DataLoader(dataset_filtered, 
                           batch_size=200,
                           collate_fn=collate_fn)

    # Distinct names for the per-sample keys so they do not shadow the
    # expert name `key` of the outer loop.
    for crops, bbox_ids, sample_keys in tqdm(loader_expert):
        texts = expert.process_batch(crops)
        for text, bbox_idx, sample_key in zip(texts, bbox_ids, sample_keys):
            ocr_dict[sample_key] += [(bbox_idx, text)]

Expert trocr-large-handwritten ...


0it [00:00, ?it/s]
0it [00:00, ?it/s]
13it [00:00, 123.81it/s]
26it [00:00, 117.87it/s]
40it [00:00, 125.81it/s]
53it [00:00, 124.56it/s]
66it [00:00, 126.44it/s]
79it [00:00, 125.54it/s]
92it [00:00, 125.37it/s]
105it [00:00, 122.63it/s]
118it [00:00, 123.36it/s]
131it [00:01, 124.59it/s]
144it [00:01, 126.19it/s]
157it [00:01, 113.78it/s]
169it [00:01, 106.68it/s]
181it [00:01, 109.94it/s]
193it [00:01, 110.58it/s]
206it [00:01, 115.98it/s]
219it [00:01, 119.63it/s]
232it [00:01, 116.25it/s]
244it [00:02, 105.99it/s]
258it [00:02, 113.83it/s]
270it [00:02, 99.71it/s] 
283it [00:02, 106.27it/s]
296it [00:02, 111.05it/s]
308it [00:02, 111.66it/s]
320it [00:02, 106.61it/s]
331it [00:02, 107.52it/s]
342it [00:02, 107.91it/s]
357it [00:03, 118.50it/s]
369it [00:03, 118.59it/s]
381it [00:03, 115.00it/s]
394it [00:03, 116.62it/s]
406it [00:03, 116.92it/s]
420it [00:03, 121.81it/s]
433it [00:03, 123.11it/s]
446it [00:03, 123.32it/s]
459it [00:03, 117.10it/s]
473it [00:04, 122.81it/s]
487it [

Expert trocr-large-printed ...


0it [00:00, ?it/s]
0it [00:00, ?it/s]
13it [00:00, 101.56it/s]
24it [00:00, 80.53it/s] 
34it [00:00, 73.00it/s]
45it [00:00, 47.87it/s]
51it [00:00, 44.09it/s]
62it [00:01, 57.14it/s]
70it [00:01, 36.22it/s]
79it [00:01, 44.56it/s]
89it [00:01, 54.52it/s]
97it [00:01, 53.78it/s]
104it [00:01, 56.23it/s]
114it [00:02, 66.02it/s]
124it [00:02, 74.14it/s]
133it [00:02, 68.35it/s]
141it [00:02, 68.68it/s]
149it [00:02, 56.81it/s]
156it [00:02, 47.17it/s]
163it [00:02, 49.18it/s]
173it [00:03, 59.16it/s]
182it [00:03, 59.30it/s]
190it [00:03, 51.84it/s]
198it [00:03, 57.72it/s]
206it [00:03, 43.12it/s]
212it [00:04, 34.35it/s]
217it [00:04, 32.18it/s]
221it [00:04, 27.47it/s]
229it [00:04, 33.49it/s]
242it [00:05, 35.16it/s]
253it [00:05, 45.95it/s]
259it [00:05, 36.74it/s]
264it [00:05, 23.89it/s]
271it [00:06, 22.48it/s]
276it [00:06, 25.31it/s]
288it [00:06, 38.61it/s]
301it [00:06, 53.36it/s]
309it [00:07, 23.56it/s]
319it [00:07, 31.16it/s]
327it [00:07, 37.15it/s]
335it [00:07, 43.15i

Expert trocr-large-str ...


0it [00:00, ?it/s]
0it [00:00, ?it/s]
4it [00:00, 29.63it/s]
7it [00:00, 21.28it/s]
12it [00:00, 26.59it/s]
15it [00:00, 17.48it/s]
18it [00:00, 19.11it/s]
21it [00:01, 20.44it/s]
27it [00:01, 19.95it/s]
31it [00:01, 21.96it/s]
40it [00:01, 33.60it/s]
44it [00:01, 34.01it/s]
48it [00:02, 12.06it/s]
51it [00:02, 12.56it/s]
55it [00:02, 15.59it/s]
58it [00:03, 13.10it/s]
61it [00:03, 15.20it/s]
65it [00:03, 16.10it/s]
68it [00:03, 17.78it/s]
71it [00:04, 10.29it/s]
75it [00:04, 13.57it/s]
78it [00:04, 13.20it/s]
81it [00:05, 12.19it/s]
88it [00:05, 19.68it/s]
91it [00:05, 18.25it/s]
97it [00:05, 25.00it/s]
101it [00:06, 13.32it/s]
104it [00:06, 14.07it/s]
108it [00:06, 16.26it/s]
112it [00:06, 19.79it/s]
115it [00:06, 19.04it/s]
122it [00:06, 27.33it/s]
1it [30:37, 1837.70s/it]
125it [30:37, 131.43s/it]
127it [30:38, 109.09s/it]
134it [30:38, 58.95s/it] 
140it [30:38, 37.73s/it]
144it [30:38, 28.10s/it]
152it [30:38, 16.25s/it]
157it [30:38, 11.82s/it]
161it [30:39,  9.02s/it]
165it [30:

Expert trocr-large-stage1 ...


0it [00:00, ?it/s]
0it [00:00, ?it/s]
4it [00:00, 11.33it/s]
7it [00:00, 12.61it/s]
12it [00:00, 17.62it/s]
14it [00:00, 14.31it/s]
18it [00:01, 19.27it/s]
23it [00:01, 25.49it/s]
27it [00:01, 17.29it/s]
31it [00:01, 20.71it/s]
39it [00:01, 28.64it/s]
45it [00:02, 26.18it/s]
49it [00:02, 26.81it/s]
53it [00:02, 22.15it/s]
60it [00:02, 29.95it/s]
68it [00:02, 38.80it/s]
73it [00:03, 26.55it/s]
78it [00:03, 24.27it/s]
82it [00:03, 21.63it/s]
89it [00:03, 28.90it/s]
93it [00:03, 28.97it/s]
98it [00:04, 25.30it/s]
102it [00:04, 22.74it/s]
105it [00:04, 20.59it/s]
108it [00:04, 19.55it/s]
111it [00:04, 16.33it/s]
114it [00:05, 14.57it/s]
121it [00:05, 22.97it/s]
125it [00:05, 14.56it/s]
128it [00:06, 14.43it/s]
131it [00:06, 14.83it/s]
134it [00:06, 16.95it/s]
142it [00:06, 19.96it/s]
146it [00:06, 20.93it/s]
1it [30:53, 1853.77s/it]
149it [30:53, 135.05s/it]
153it [30:53, 95.27s/it] 
159it [30:54, 58.44s/it]
164it [30:54, 40.02s/it]
171it [30:54, 24.55s/it]
175it [30:55, 18.67s/it]
183it [

In [None]:
# Order each sample's (bbox_idx, text) pairs by bbox_idx and keep only the text.
ocr_dict_sorted = defaultdict(list)
for key, indexed_texts in ocr_dict.items():
    ordered_pairs = sorted(indexed_texts, key=lambda pair: pair[0])
    ocr_dict_sorted[key] = [pair[1] for pair in ordered_pairs]

In [None]:
# Translate each crop's numeric label into the expert name at that position.
expert_names = list(expert_dict)
expert_name_dict = defaultdict(list)
for key in bbox_label_dict:
    expert_name_dict[key] = [expert_names[label] for label in bbox_label_dict[key]]

# serialize result

In [None]:
# Read the input parquet into a DataFrame; the OCR result columns are added to it.
df = pd.read_parquet(parquet_fname, engine='pyarrow')

In [None]:
# One entry per row of `df`; rows without any detected box keep the defaults.
n_rows = len(df)
ocr_col = [''] * n_rows
# A distinct empty list per row: `[[]] * n` would place the *same* list object
# in every slot, so mutating one default would silently change all of them.
bbox_col = [[] for _ in range(n_rows)]
exp_col = [''] * n_rows

for key in bbox_dict.keys():
    # NOTE(review): assumes each sample key is the string form of the row's
    # position in `df` — confirm against how `load_dataset` assigns keys.
    idx = int(key)
    ocr_col[idx] = ocr_dict_sorted[key]
    bbox_col[idx] = bbox_dict[key]
    exp_col[idx] = expert_name_dict[key]

df['OCR_BBOXES'] = bbox_col
df['OCR_EXPERTS'] = exp_col
df['OCR_TEXT'] = ocr_col


In [None]:
# Convert the three OCR columns to their string representation.
for ocr_column in ('OCR_BBOXES', 'OCR_EXPERTS', 'OCR_TEXT'):
    df[ocr_column] = df[ocr_column].astype(str)

In [None]:
import fastparquet as pq
import os
# Write the augmented frame to `parquet_result_fname`, resolved to an absolute path.
pq.write(os.path.abspath(parquet_result_fname), df)

# eyeball result

In [None]:
from IPython.display import Image, display
from IPython.core.display import HTML 

In [None]:
# Show every row that has at least one crop routed to the handwritten expert,
# followed by its stored OCR text and expert names.
handwritten_expert = f'trocr-{expert_size}-handwritten'
handwritten_idcs = []
for idx, exps in enumerate(exp_col):
    if handwritten_expert not in exps:
        continue
    handwritten_idcs.append(idx)
    row = df.iloc[idx]
    display(Image(url=row['URL']))
    print(row['OCR_TEXT'])
    print(row['OCR_EXPERTS'])
    print()
print(handwritten_idcs)

In [None]:
# `idx` is not set in this cell: it is whatever an earlier loop left behind
# (the last index of the `exp_col` loop when the notebook is run top to bottom).
Image(url=df.iloc[idx]['URL'])

In [24]:
# OCR text stored for row `idx` (a string, since the column was cast with `astype(str)`).
df.iloc[idx]['OCR_TEXT']

'[]'

In [25]:
# Expert names stored for row `idx` (string form).
df.iloc[idx]['OCR_EXPERTS']

'[]'

In [26]:
# Bounding boxes stored for row `idx` (string form).
df.iloc[idx]['OCR_BBOXES']

'[]'