In [108]:
import datasets

# 'embeddings' column of the dataset saved under "embeddings"; these vectors are
# later added to the faiss nearest-neighbour index.
embeddings = datasets.load_from_disk("embeddings")['embeddings']
# 'labels' column of the dataset saved under "labels"; later enumerated to map
# formula tokens to integer ids.
labels = datasets.load_from_disk("labels")['labels']

In [109]:
import numpy as np
import skimage.transform as sktransform

# Target (rows, cols) of every thumbnail; 8 * 8 = 64 matches the index dimensionality used below.
new_shape = (8, 8)
def scale_image(image):
    """Resize ``image`` to ``new_shape`` with anti-aliasing enabled.

    Returns whatever ``skimage.transform.resize`` produces for that request.
    """
    resized = sktransform.resize(image, output_shape=new_shape, anti_aliasing=True)
    return resized

def refine_region(region):
    """Turn a 2-D region into a fixed-size uint8 thumbnail.

    Steps, in order:
      1. convert to float32;
      2. if the short side is less than 10% of the long side, keep only the
         top-left ``short x short`` corner;
      3. zero-pad at the end of the shorter axis so the array is square;
      4. resize with ``scale_image``;
      5. normalise by the maximum, apply ``sqrt(2r - r**2)`` and quantise to
         0..255.

    Parameters
    ----------
    region : array-like
        2-D array (the caller passes a 0/1 component mask).

    Returns
    -------
    numpy.ndarray
        uint8 array with the shape produced by ``scale_image``. If the resized
        region has no positive value (for example because the corner crop in
        step 2 removed every set pixel) an all-zero array is returned instead
        of dividing by zero.
    """
    region = np.array(region, dtype=np.float32)
    short_side = min(region.shape)
    long_side = max(region.shape)
    if short_side / long_side < 0.1:
        # Extremely elongated: keep only the leading square corner.
        region = region[0:short_side, 0:short_side]
    rows, cols = region.shape[0], region.shape[1]
    if rows > cols:
        # Taller than wide: append zero columns.
        region = np.pad(region, ((0, 0), (0, rows - cols)), 'constant', constant_values=0)
    elif rows < cols:
        # Wider than tall: append zero rows.
        region = np.pad(region, ((0, cols - rows), (0, 0)), 'constant', constant_values=0)
    region = scale_image(region)
    peak = region.max()
    if not peak > 0:
        # Nothing to normalise (also catches NaN); avoid 0/0 -> NaN -> undefined uint8 cast.
        return np.zeros(region.shape, dtype=np.uint8)
    region = region / peak
    # 2r - r**2 == 1 - (1 - r)**2: a quarter-circle curve that lifts faint values.
    region = np.sqrt(2 * region - region**2)
    region = np.round(region * 255).astype(np.uint8)
    return region

In [110]:
import faiss

# Exact L2 index over 64-dimensional vectors.
index = faiss.IndexFlatL2(64)

# Faiss consumes float32 matrices of shape (n, d). Build that matrix once and add
# it in a single call instead of adding one uint8 array per Python iteration.
# The uint8 cast is kept so stored values are quantised exactly as before.
index.add(
    np.array(list(embeddings), dtype=np.uint8)
    .astype(np.float32)
    .reshape(-1, index.d)
)
    
def get_embedding(region):
    """Return the id of the indexed vector closest to ``region``.

    Parameters
    ----------
    region : array-like
        Query data holding ``index.d`` values per row; it is converted to a
        contiguous float32 matrix of shape ``(n, index.d)`` before searching.
        Only the first row's result is used.

    Returns
    -------
    numpy.ndarray
        One-element array containing the position (within ``index``) of the
        nearest stored vector. Note that this is an id, not the vector itself.
    """
    # Faiss searches float32 (n, d) matrices; callers pass uint8 thumbnails.
    query = np.ascontiguousarray(region, dtype=np.float32).reshape(-1, index.d)
    _, neighbour_ids = index.search(query, k=1)
    return neighbour_ids[0][0].flatten()

In [111]:
import cv2
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

def load_image(path):
    """Read the file at ``path`` with ``cv2.imread`` and return the result unchecked."""
    pixels = cv2.imread(path)
    return pixels

def load_all_images(image_paths, num_workers=None):
    """Load every path in ``image_paths`` through ``load_image`` on a thread pool.

    ``num_workers`` is forwarded as the pool's ``max_workers``. A tqdm bar tracks
    progress, and the results are returned as a list in input order.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        pending = pool.map(load_image, image_paths)
        progress = tqdm(pending, total=len(image_paths))
        loaded = list(progress)
    return loaded

def rgb_to_gray(im):
    """Return the inverted first channel of ``im`` (``255 - channel 0``).

    Only channel 0 is read; the other channels are ignored.
    """
    first_channel = im[:, :, 0]
    inverted = 255 - first_channel
    return inverted

# Folder prefix prepended to every image file name in process_data.
# NOTE(review): absolute, machine-specific Windows path; adjust before running elsewhere.
directory = 'D:\\formula_images\\formula_images\\'
    
def get_image(image):
    """Split an image into connected components and look each one up in the index.

    The image is reduced with ``rgb_to_gray`` and transposed before labelling with
    8-connectivity, so the bounding-box ``x``/``y`` reported by OpenCV refer to the
    transposed frame (``y`` runs along the original image's second axis). Label 0
    is skipped.

    Returns
    -------
    tuple(list, list)
        ``ems``: one ``get_embedding`` result per component.
        ``pos``: one ``(y, x, y + h - 1, x + w - 1)`` tuple per component.
    """
    flipped = rgb_to_gray(image).transpose()
    component_count, label_map, stats, _ = cv2.connectedComponentsWithStats(flipped, connectivity=8)
    ems, pos = [], []
    for component in range(1, component_count):
        x, y, w, h = stats[component][:4]
        pos.append((y, x, y + h - 1, x + w - 1))
        # Binary mask of this component inside its bounding box, flipped back
        # to the orientation of the original image.
        mask = (label_map[y:y+h, x:x+w] == component).astype(np.uint8)
        thumbnail = refine_region(mask.transpose())
        ems.append(get_embedding(thumbnail.reshape(1, -1)))
    return ems, pos

def get_all_images(images, num_workers=None):
    """Run ``get_image`` over ``images`` on a thread pool.

    ``num_workers`` becomes the pool's ``max_workers``. Progress is shown with
    tqdm and the per-image results come back as a list in input order.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        pending = pool.map(get_image, images)
        progress = tqdm(pending, total=len(images))
        results = list(progress)
    return results

In [112]:
# Token -> position in `labels`; if a token repeats, its last position wins.
label_id = {token: position for position, token in enumerate(labels)}

def process_formula(formula):
    """Convert a whitespace-separated formula string into a list of token ids.

    Each token is looked up in ``label_id`` and shifted by +1, so the smallest
    id produced is 1 (presumably leaving 0 free for padding - confirm against
    the consumer).

    Tokens are split on any run of whitespace, so repeated spaces or tabs do
    not create empty tokens; an empty or blank string yields ``[]``.

    Raises
    ------
    KeyError
        If a token is not present in ``label_id``.
    """
    return [label_id[lb] + 1 for lb in formula.split()]

def process_all_formulas(formulas, num_workers=None):
    """Apply ``process_formula`` to every entry of ``formulas`` on a thread pool.

    ``num_workers`` is used as the pool's ``max_workers``. Returns the converted
    formulas as a list in input order, with a tqdm progress bar while running.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        pending = pool.map(process_formula, formulas)
        progress = tqdm(pending, total=len(formulas))
        converted = list(progress)
    return converted

# Thread-pool size passed to load_all_images / get_all_images by process_data.
max_workers = 24
# Number of rows process_data loads and processes per batch (5 per worker).
block_size = max_workers * 5

def process_data(data, max_data = -1):
    """Build training records from a table of image file names and formulas.

    Rows are handled in batches of ``block_size``: the images are loaded, split
    into components by ``get_all_images``, and paired with the tokenised formula.

    Parameters
    ----------
    data : mapping-like with ``'formula'`` and ``'image'`` columns
        Both columns are indexed with integer positions ``0 .. len(data) - 1``
        (the notebook passes ``pandas.read_csv`` frames).
    max_data : int, optional
        Number of leading rows to process. ``-1`` (default) means all rows;
        values larger than ``len(data)`` are clamped to ``len(data)``.

    Returns
    -------
    list of dict
        One ``{'embeddings', 'pos', 'formula'}`` record per row whose formula
        is a string; rows with a non-string formula (e.g. NaN) are skipped.
    """
    ans = []
    row_count = len(data)
    if max_data == -1 or max_data > row_count:
        # Clamp so an oversized limit cannot index past the end of the table.
        max_data = row_count
    latex_formulas = data['formula']
    image_paths = data['image']
    for block in range(0, max_data, block_size):
        # NOTE(review): this condition is true for exactly one batch (the first
        # one starting at or beyond row 1000), so the message prints once.
        if block >= 1000 and block - block_size < 1000:
            print('Processed:', block)
        block_end = min(block + block_size, max_data)
        images = load_all_images([directory + image_paths[i] for i in range(block, block_end)], max_workers)
        inputs = get_all_images(images, max_workers)
        for offset, (ems, pos) in enumerate(inputs):
            formula = latex_formulas[block + offset]
            if not isinstance(formula, str):
                continue
            ans.append({'embeddings': ems, 'pos': pos, 'formula': process_formula(formula)})
    return ans

In [None]:
import pandas

# Run the full pipeline on each split CSV; every result is a list of record dicts.
# The paths are relative to the notebook's working directory and use Windows separators.
train = process_data(pandas.read_csv('.\\..\\dataset\\im2latex_train.csv'))
test = process_data(pandas.read_csv('.\\..\\dataset\\im2latex_test.csv'))
validate = process_data(pandas.read_csv('.\\..\\dataset\\im2latex_validate.csv'))

In [None]:
# Wrap each list of records in a Dataset and write it to its own folder.
datasets.Dataset.from_list(train).save_to_disk('train_dataset')
datasets.Dataset.from_list(test).save_to_disk('test_dataset')
datasets.Dataset.from_list(validate).save_to_disk('validate_dataset')

Saving the dataset (0/1 shards):   0%|          | 0/10284 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8370 [00:00<?, ? examples/s]