In [1]:
import datasets

# Load the prototype glyph embeddings and their token labels. Each is read from the
# on-disk dataset directory of the same name, keeping the column of that same name.
embeddings, labels = (datasets.load_from_disk(name)[name] for name in ("embeddings", "labels"))

In [2]:
import numpy as np
import skimage.transform as sktransform

# Target (rows, cols) every glyph is resampled to; flattened it gives 8 * 8 = 64 values.
new_shape = (8, 8)
def scale_image(image):
    """Resample a 2-D image to ``new_shape`` with anti-aliasing (skimage ``resize``)."""
    resized = sktransform.resize(image, output_shape=new_shape, anti_aliasing=True)
    return resized

def refine_region(region):
    """Normalise a glyph mask into a fixed-size uint8 feature image.

    Steps:
      1. zero-pad on the bottom/right to a square (the glyph stays anchored top-left),
      2. resample to ``new_shape`` via ``scale_image``,
      3. rescale so the brightest pixel is 1,
      4. apply f(r) = sqrt(2r - r^2) = sqrt(1 - (1 - r)^2), which satisfies f(r) >= r
         on [0, 1] and so lifts faint (anti-aliased) pixels,
      5. quantise to 0..255.

    Args:
        region: 2-D array-like, e.g. a 0/1 connected-component mask.

    Returns:
        uint8 ndarray with shape ``new_shape``. An all-zero input yields all zeros
        (previously 0/0 produced NaNs that were cast to uint8).
    """
    region = np.array(region, dtype=np.float32)
    rows, cols = region.shape
    side = max(rows, cols)
    # Pad only the short dimension, on its far edge, to make the region square.
    region = np.pad(region, ((0, side - rows), (0, side - cols)), 'constant', constant_values=0)
    region = scale_image(region)
    peak = region.max()
    if peak > 0:  # guard against 0/0 for an empty mask
        region = region / peak
    region = np.sqrt(2 * region - region**2)
    return np.round(region * 255).astype(np.uint8)

In [3]:
import faiss

# Flat L2 index over the 64-dim prototype embeddings; 64 = 8 * 8, the flattened size
# of a refine_region output (get_image queries with refine_region(...).reshape(1, -1)).
index = faiss.IndexFlatL2(64)

# Stack all prototypes into a single (n, 64) matrix and add them in one call instead of
# one faiss call per row. The uint8 cast reproduces the original per-row conversion;
# the float32 cast is the dtype faiss indexes work in, so no implicit conversion is relied on.
_prototypes = np.array(list(embeddings), dtype=np.uint8).astype(np.float32)
if len(_prototypes):  # an empty embeddings list adds nothing, as the original loop did
    index.add(_prototypes)
    
def get_embedding(region):
    """Look up the nearest stored prototype for one query row.

    Args:
        region: query of shape (1, d); get_image passes
            ``refine_region(...).reshape(1, -1)``.

    Returns:
        1-element integer ndarray holding the row id (in the global ``index``) of the
        single nearest neighbour under L2 distance.
    """
    _distances, neighbour_ids = index.search(region, k=1)
    best_id = neighbour_ids[0][0]
    return best_id.flatten()

In [4]:
import cv2
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

def load_image(path):
    """Read the image file at ``path`` with OpenCV.

    OpenCV returns None rather than raising when the file cannot be read; that value
    is passed through unchanged.
    """
    picture = cv2.imread(path)
    return picture

def load_all_images(image_paths, num_workers=None):
    """Load every path in ``image_paths`` on a thread pool, showing a tqdm bar.

    Args:
        image_paths: sized sequence of file paths.
        num_workers: thread count; None lets ThreadPoolExecutor pick its default.

    Returns:
        list of ``load_image`` results in the same order as ``image_paths``.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        in_order = pool.map(load_image, image_paths)  # Executor.map preserves input order
        images = [img for img in tqdm(in_order, total=len(image_paths))]
    return images

def rgb_to_gray(im):
    """Return ``255 -`` the first channel of an H x W x C image (dark pixels become high).

    NOTE(review): despite the name no R/G/B weighting is done — only channel 0 (blue in
    OpenCV's BGR order) is used, which presumes the source image is already greyscale.
    """
    first_channel = im[:, :, 0]
    inverted = 255 - first_channel
    return inverted

import os

# Folder holding the formula images named in the CSVs' 'image' column. Defaults to the
# original hardcoded location; set FORMULA_IMAGE_DIR to run on another machine.
directory = os.environ.get('FORMULA_IMAGE_DIR', 'D:\\formula_images\\formula_images\\')
# process_data builds paths as `directory + file name`, so a trailing separator is required.
if not directory.endswith(('\\', '/')):
    directory += os.sep
    
def get_image(image):
    """Split a formula image into connected components and embed each component.

    Args:
        image: H x W x C array as returned by ``load_image``.

    Returns:
        (ems, pos), parallel lists with one entry per foreground component:
        ems[k] is ``get_embedding`` of the component's refined mask, and pos[k] is
        (x0, y0, x1, y1) — inclusive column/row bounds in the original image.
    """
    # Components are found on the transposed image, so OpenCV's stats columns are
    # swapped relative to the original: x/w refer to original rows, y/h to columns.
    # NOTE(review): presumably done so raster-order labelling sorts components
    # left-to-right in the original image — confirm intent.
    transposed = rgb_to_gray(image).transpose()
    count, label_map, stats, _ = cv2.connectedComponentsWithStats(transposed, connectivity=8)
    ems, pos = [], []
    for comp in range(1, count):  # OpenCV reserves label 0 for the background
        t_x, t_y, t_w, t_h, _area = stats[comp]
        pos.append((t_y, t_x, t_y + t_h - 1, t_x + t_w - 1))
        # Crop this component's mask and transpose it back to original orientation.
        mask = (label_map[t_y:t_y + t_h, t_x:t_x + t_w] == comp).astype(np.uint8).transpose()
        query = refine_region(mask).reshape(1, -1)
        ems.append(get_embedding(query))
    return ems, pos

In [5]:
# Token -> 0-based position in `labels` (if a token repeats, its last position wins).
label_id = {label: i for i, label in enumerate(labels)}

def process_formula(formula, vocab=None):
    """Map a whitespace-tokenised LaTeX formula string to 1-based token ids.

    Args:
        formula: tokens separated by whitespace. Runs of spaces, tabs and
            leading/trailing whitespace are ignored; previously ``split(' ')``
            produced empty-string tokens and a KeyError.
        vocab: token -> 0-based index mapping; defaults to the global ``label_id``.

    Returns:
        list[int] with ``vocab[token] + 1`` per token (NOTE(review): the +1 presumably
        reserves id 0 for padding — confirm with the consumer). An empty or
        all-whitespace formula gives ``[]``.

    Raises:
        KeyError: if a token is missing from ``vocab``.
    """
    if vocab is None:
        vocab = label_id
    return [vocab[token] + 1 for token in formula.split()]

# process_data batching: threads used to read images, and rows per batch.
max_workers, block_size = 24, 48

def process_data(data, max_data = -1):
    """Turn the first ``max_data`` rows of an im2latex split into feature records.

    Rows are handled in batches of ``block_size``. Each batch's images are read
    concurrently via ``load_all_images(..., max_workers)`` from ``directory + file name``.
    Every row whose formula is a string is then converted with ``process_formula``
    and ``get_image``.

    Args:
        data: table with ``'image'`` and ``'formula'`` columns (e.g. a DataFrame
            from ``pandas.read_csv``). Rows are looked up as ``data[col][i]`` for
            i = 0, 1, ..., so a default 0..len-1 index is assumed.
        max_data: number of leading rows to process; ``-1`` (default) means all.
            Values larger than ``len(data)`` are clamped instead of raising.

    Returns:
        list of ``{'embeddings': ..., 'pos': ..., 'formula': ...}`` dicts in row
        order. Rows whose formula is not a ``str`` (e.g. a float NaN for a missing
        cell) are skipped, so the list can be shorter than ``max_data``.
    """
    total_rows = len(data)
    # -1 means "everything"; otherwise never index past the end of the table.
    max_data = total_rows if max_data == -1 else min(max_data, total_rows)
    ans = []
    for block_start in range(0, max_data, block_size):
        block_end = min(block_start + block_size, max_data)
        rows = range(block_start, block_end)
        images = load_all_images([directory + data['image'][i] for i in rows], max_workers)
        for i in rows:
            raw_formula = data['formula'][i]
            if not isinstance(raw_formula, str):
                continue
            formula = process_formula(raw_formula)
            ems, pos = get_image(images[i - block_start])
            ans.append({'embeddings': ems, 'pos': pos, 'formula': formula})
    return ans

In [6]:
import pandas

import os

# Split CSVs live in ../dataset relative to the working directory. Building the path
# with os.path.join keeps it portable; backslash-only literals resolve only on Windows.
DATASET_DIR = os.path.join('..', 'dataset')

def load_split(split, max_rows):
    """Read ``im2latex_<split>.csv`` from DATASET_DIR and process its first ``max_rows`` rows."""
    csv_path = os.path.join(DATASET_DIR, f'im2latex_{split}.csv')
    return process_data(pandas.read_csv(csv_path), max_rows)

train = load_split('train', 10000)
test = load_split('test', 3000)
validate = load_split('validate', 1000)

100%|██████████| 48/48 [00:00<00:00, 2584.36it/s]
100%|██████████| 48/48 [00:00<00:00, 1179.52it/s]
100%|██████████| 48/48 [00:00<00:00, 852.20it/s]
100%|██████████| 48/48 [00:00<00:00, 804.84it/s]
100%|██████████| 48/48 [00:00<00:00, 1576.24it/s]
100%|██████████| 48/48 [00:00<00:00, 1077.83it/s]
100%|██████████| 48/48 [00:00<00:00, 1417.51it/s]
100%|██████████| 48/48 [00:00<00:00, 1232.74it/s]
100%|██████████| 48/48 [00:00<00:00, 1250.72it/s]
100%|██████████| 48/48 [00:00<00:00, 987.11it/s]
100%|██████████| 48/48 [00:00<00:00, 1486.13it/s]
100%|██████████| 48/48 [00:00<00:00, 1652.26it/s]
100%|██████████| 48/48 [00:00<00:00, 5382.92it/s]
100%|██████████| 48/48 [00:00<00:00, 1029.19it/s]
100%|██████████| 48/48 [00:00<00:00, 789.11it/s]
100%|██████████| 48/48 [00:00<00:00, 926.26it/s]
100%|██████████| 48/48 [00:00<00:00, 740.83it/s]
100%|██████████| 48/48 [00:00<00:00, 1138.06it/s]
100%|██████████| 48/48 [00:00<00:00, 639.74it/s]
100%|██████████| 48/48 [00:00<00:00, 696.29it/s]
100%|███

In [7]:
# Persist each split as a Hugging Face dataset directory (readable back with
# datasets.load_from_disk, like the "embeddings"/"labels" inputs).
datasets.Dataset.from_list(train).save_to_disk('train_dataset')
datasets.Dataset.from_list(test).save_to_disk('test_dataset')
datasets.Dataset.from_list(validate).save_to_disk('validate_dataset')

Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2976 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]