In [64]:
import numpy as np
import skimage.transform as sktransform

new_shape = (8, 8)
def scale_image(image):
    return sktransform.resize(image, new_shape, anti_aliasing=True)

def refine_region(region):
    region = np.array(region, dtype=np.float32)
    if min(region.shape) / max(region.shape) < 0.1:
        region = region[0:min(region.shape), 0:min(region.shape)]
    if region.shape[0] > region.shape[1]:
        region = np.pad(region, ((0, 0), (0, region.shape[0] - region.shape[1])), 'constant', constant_values=0)
    elif region.shape[0] < region.shape[1]:
        region = np.pad(region, ((0, region.shape[1] - region.shape[0]), (0, 0)), 'constant', constant_values=0)
    region = scale_image(region)
    region = region / region.max()
    region = np.sqrt(2 * region - region**2)
    region = np.round(region * 255).astype(np.uint8)
    return region

In [65]:
import faiss

index = faiss.IndexFlatL2(64) 
region_list = []

def add_embedding(region):
    if index.ntotal == 0:
        index.add(region)
        region_list.append(region)
    else:
        em = index.search(region, k=1)[1][0][0]
        if np.sum((region - region_list[em]) ** 2) > 0.001:
            index.add(region)
            region_list.append(region)

In [66]:
import cv2
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

def load_image(path):
    return cv2.imread(path)

def load_all_images(image_paths, num_workers=None):
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        images = list(tqdm(executor.map(load_image, image_paths), total=len(image_paths)))
    return images

def rgb_to_gray(im):
    return 255 - im[:, :, 0]

directory = 'D:\\formula_images\\formula_images\\'
    
def process_image(image):
    image = rgb_to_gray(image).transpose()
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(image, connectivity=8)
    for label in range(1, num_labels): 
        x, y, w, h, area = stats[label]
        add_embedding(refine_region((labels[y:y+h, x:x+w] == label).astype(np.uint8).transpose()).reshape(1, -1))

In [67]:
max_workers = 24
block_size = 48

def process_data(data, max_data = -1):
    if max_data == -1:
        max_data = len(data)
    for block in range(0, max_data, block_size):
        if block >= max_data:
            break
        images = load_all_images([directory + data['image'][i] for i in range(block, block + block_size) if i < max_data], max_workers)
        for image in images:
            process_image(image)

In [68]:
import pandas

process_data(pandas.read_csv('.\\..\\dataset\\im2latex_train.csv'))

100%|██████████| 48/48 [00:00<00:00, 1062.89it/s]
100%|██████████| 48/48 [00:00<00:00, 1754.17it/s]
100%|██████████| 48/48 [00:00<00:00, 1701.36it/s]
100%|██████████| 48/48 [00:00<00:00, 3436.84it/s]
100%|██████████| 48/48 [00:00<00:00, 843.90it/s]
100%|██████████| 48/48 [00:00<00:00, 907.56it/s]
100%|██████████| 48/48 [00:00<00:00, 1086.34it/s]
100%|██████████| 48/48 [00:00<00:00, 1338.52it/s]
100%|██████████| 48/48 [00:00<00:00, 1342.26it/s]
100%|██████████| 48/48 [00:00<00:00, 1323.41it/s]
100%|██████████| 48/48 [00:00<00:00, 891.12it/s]
100%|██████████| 48/48 [00:00<00:00, 795.90it/s]
100%|██████████| 48/48 [00:00<00:00, 394.36it/s]
100%|██████████| 48/48 [00:00<00:00, 606.14it/s]
100%|██████████| 48/48 [00:00<00:00, 526.63it/s]
100%|██████████| 48/48 [00:00<00:00, 330.52it/s]
100%|██████████| 48/48 [00:00<00:00, 457.42it/s]
100%|██████████| 48/48 [00:00<00:00, 404.77it/s]
100%|██████████| 48/48 [00:00<00:00, 334.10it/s]
100%|██████████| 48/48 [00:00<00:00, 373.16it/s]
100%|███████

In [77]:
import datasets

datasets.Dataset.save_to_disk(
    datasets.Dataset.from_dict({'embeddings': region_list}),
    'embeddings'
)

Saving the dataset (0/1 shards):   0%|          | 0/8115 [00:00<?, ? examples/s]

In [78]:
len(region_list)

8115