In [1]:
from PIL import Image
import re
import numpy as np
from IPython.display import display, Math, Latex

In [2]:
def crop_to_formula(image, padding = 30):
    # Image: 4 channel image with alpha.
    # Convert black pixels to white pixels.
    data = np.array(image)
    red, green, blue, alpha = data.T
    black_areas = (red < 10) & (blue < 10) & (green < 10)
    # Convert alpha to white.
    data[..., -1] = 255
    # Crop a box around the area that contains black pixels.
    coords = np.argwhere(black_areas)
    x0, y0 = coords.min(axis=0)
    x1, y1 = coords.max(axis=0) + 1
    # Add padding.
    x0 = max(0, x0 - padding)
    y0 = max(0, y0 - padding)
    x1 = min(image.width, x1 + padding)
    y1 = min(image.height, y1 + padding)
    image = Image.fromarray(data[y0:y1, x0:x1])
    return image.convert('RGB')

In [3]:
def renderedLaTeXLabelstr2Formula(label: str):
    # We're matching \\label{...whatever} and removing it
    label = re.sub(r"\\label\{[^\}]*\}", "", label)
    # We match \, and remove it.
    label = re.sub(r"\\,", "", label)
    return label

In [4]:
def display_formula(latex: str):
    # Remove \mbox{...} - not supported by the inline MathJax renderer
    parsed_latex = re.sub(r"\\mbox\{[^\}]*\}", "", latex)
    display(Math(parsed_latex))

In [14]:
import torch as t
import requests
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torchvision.transforms.v2
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0)
try:
    args = parser.parse_args()
except SystemExit:
    args = argparse.Namespace(gpu=0)  # Default to 0 if running in Jupyter


usage: ipykernel_launcher.py [-h] [--gpu GPU]
ipykernel_launcher.py: error: unrecognized arguments: -f C:\Users\prana\AppData\Roaming\jupyter\runtime\kernel-021287d3-cbdb-445a-ba44-1ffaace27713.json


In [10]:
device = t.device('cuda:{}'.format(args.gpu) if t.cuda.is_available() else 'cpu')

In [11]:
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-large-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Load LST files
import pandas as pd
import numpy as np
import re
from tqdm import tqdm, trange

In [13]:
import sys, os
sys.path.append(os.path.abspath('../'))
# from utils.latex import crop_to_formula, renderedLaTeXLabelstr2Formula, display_formula

In [7]:
import pandas as pd

In [15]:
train_filenames_df = pd.read_csv(r"C:\Users\prana\Downloads\56198\im2latex_train.lst", index_col = 0, header = None, sep = " ")
val_filenames_df = pd.read_csv(r"C:\Users\prana\Downloads\56198\im2latex_validate.lst", index_col = 0, header = None, sep = " ")
formulas = open(r"C:\Users\prana\Downloads\56198\im2latex_formulas.lst", encoding = "ISO-8859-1", newline="\n").readlines()

In [16]:
print("Number of training formulas: ", len(train_filenames_df))
print("Number of validation formulas: ", len(val_filenames_df))

Number of training formulas:  83884
Number of validation formulas:  9320


In [17]:
max_len = max([len(formula) for formula in formulas])
print("Max length:", max_len)

Max length: 998


In [18]:
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
import numpy as np
from PIL import Image
from torchvision import transforms
from sklearn.model_selection import train_test_split
import sklearn as skl
from data.datasets import renderedLaTeXDataset, set_seed
from data.dataset_tests import test_renderedLaTeXDataset
    
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

ModuleNotFoundError: No module named 'utils.latex'

In [19]:
# Hyperparams
NUM_EPOCHS = 2
LEARNING_RATE = 1e-5
BATCH_SIZE = 4 # 10 gigs of Vram -> 4, <5 gigs of vram -> 2
SHUFFLE_DATASET = True

In [21]:
import random
import numpy as np
import torch as t

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    t.manual_seed(seed)
    if t.cuda.is_available():
        t.cuda.manual_seed_all(seed)  # If using CUDA


In [25]:
# from data.datasets import renderedLaTeXDataset, set_seed
# from data.dataset_tests import test_renderedLaTeXDataset
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import os, sys
sys.path.append(os.path.abspath('../'))
# from utils.latex import crop_to_formula, renderedLaTeXLabelstr2Formula, display_formula

from PIL import Image
from torchvision import transforms
from sklearn.model_selection import train_test_split
import sklearn as skl
import torch as t

class renderedLaTeXDataset(Dataset):
    def __init__(self, image_folder, lst_file, formulas_file, processor, device, cutoff = None):
        self.image_folder = image_folder
        self.lst_file = lst_file
        self.formulas_file = formulas_file
        self.train_filenames_df = pd.read_csv(self.lst_file, sep=" ", index_col = 0, header = None)
        self.formulas = open(self.formulas_file, encoding = "ISO-8859-1", newline="\n").readlines()
        self.processor = processor
        self.device = device
        self.cutoff = cutoff if cutoff else len(self.train_filenames_df)
        if cutoff is not None:
            self.train_filenames_df = self.train_filenames_df.iloc[:self.cutoff]
            self.formulas = self.formulas[:self.cutoff]
        
    def __len__(self):
        return self.cutoff
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.image_folder, self.train_filenames_df.iloc[idx, 0] + ".png")
        image = Image.open(img_name).convert('RGBA')
        image = crop_to_formula(image)
        inputs = self.processor(images = image,  padding = "max_length", return_tensors="pt").to(self.device)
        for key in inputs:
            inputs[key] = inputs[key].squeeze() # Get rid of batch dimension since the dataloader will batch it for us.

        formula_idx = self.train_filenames_df.index[idx]
        caption = renderedLaTeXLabelstr2Formula(self.formulas[formula_idx])
        caption = self.processor.tokenizer.encode(
            caption, return_tensors="pt", padding = "max_length", max_length = 512, truncation = True, # Tweak this
            ).to(self.device).squeeze()
        
        return inputs, caption
    
def set_seed(seed):
    np.random.seed(seed)
    t.manual_seed(seed)
    if t.cuda.is_available():
        t.cuda.manual_seed_all(seed)
    skl.utils.check_random_state(seed)

In [26]:
# Credit goes to https://www.kaggle.com/code/kalikichandu/preprossing-inkml-to-png-files for original code.

import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage.transform import resize
import xml.etree.ElementTree as ET
import os
import numpy as np
from tqdm import tqdm
import cv2
import collections

def get_traces_data(inkml_file_abs_path):
    

    traces_data = []
    
    tree = ET.parse(inkml_file_abs_path)
    root = tree.getroot()
    doc_namespace = "{http://www.w3.org/2003/InkML}"

#   'Stores traces_all with their corresponding id'
    traces_all = [{'id': trace_tag.get('id'),
    					'coords': [[round(float(axis_coord)) if float(axis_coord).is_integer() else round(float(axis_coord) * 10000) \
    									for axis_coord in coord[1:].split(' ')] if coord.startswith(' ') \
    								else [round(float(axis_coord)) if float(axis_coord).is_integer() else round(float(axis_coord) * 10000) \
    									for axis_coord in coord.split(' ')] \
    							for coord in (trace_tag.text).replace('\n', '').split(',')]} \
    							for trace_tag in root.findall(doc_namespace + 'trace')]

#   'Sort traces_all list by id to make searching for references faster'
    traces_all.sort(key=lambda trace_dict: int(trace_dict['id']))

#   'Always 1st traceGroup is a redundant wrapper'
    traceGroupWrapper = root.find(doc_namespace + 'traceGroup')

    if traceGroupWrapper is not None:
        for traceGroup in traceGroupWrapper.findall(doc_namespace + 'traceGroup'):

            label = traceGroup.find(doc_namespace + 'annotation').text

#    'traces of the current traceGroup'
            traces_curr = []
            for traceView in traceGroup.findall(doc_namespace + 'traceView'):

#     'Id reference to specific trace tag corresponding to currently considered label'
                traceDataRef = int(traceView.get('traceDataRef'))

#     'Each trace is represented by a list of coordinates to connect'
                single_trace = traces_all[traceDataRef]['coords']
                traces_curr.append(single_trace)

            traces_data.append({'label': label, 'trace_group': traces_curr})

    else:
#             'Consider Validation data that has no labels'
        [traces_data.append({'trace_group': [trace['coords']]}) for trace in traces_all]

    return traces_data

def get_gt(inkml_file_abs_path):
    tree = ET.parse(inkml_file_abs_path)
    root = tree.getroot()
    doc_namespace = "{http://www.w3.org/2003/InkML}"
    annotation = root.find(f".//{doc_namespace}annotation[@type='truth']")
    if annotation is not None:
        truth = annotation.text
    else: raise Exception("No truth annotation found.")
    return truth
def inkml2img(input_path, output_path):
    traces = get_traces_data(input_path)
    if not traces:
        print(f"No traces found for {input_path}.")
        return  # Exit if no traces found

    path = os.path.basename(input_path).split('.')[0] + '_'
    file_name = 0
    plt.axis('off')
    plt.gca().invert_yaxis()
    plt.gca().set_aspect('equal', adjustable='box')
    plt.gca().set_xticks([])
    plt.gca().set_yticks([])

    for elem in traces:
        ls = elem['trace_group']
        for subls in ls:
            data = np.array(subls)
            if data.shape[1] > 2:
                data = data[:, :2]
            x, y = zip(*data)
            plt.plot(x, y, linewidth=2, c='black')

    try:
        os.makedirs(output_path, exist_ok=True)  # Create directory if it doesn't exist
    except OSError as e:
        print(f"Error creating directory {output_path}: {e}")
        return  # Exit on error

    output_file = os.path.join(output_path, f"{path}{file_name}.png")
    plt.savefig(output_file, bbox_inches='tight', dpi=100)
    print(f"Saved image: {output_file}")  # Debug info for saved files
    plt.gcf().clear()



# def inkml2img(input_path, output_path):
# #     print(input_path)
# #     print(pwd)
#     traces = get_traces_data(input_path)
# #     print(traces)
#     path = input_path.split('/')
#     path = path[len(path)-1].split('.')
#     path = path[0]+'_'
#     file_name = 0
#     # Get rid of all matplotlib elements
#     plt.axis('off')
#     # plt.gca().set_position([0, 0, 1, 1])
#     plt.gca().invert_yaxis()
#     plt.gca().set_aspect('equal', adjustable='box')
#     # plt.gca().set_axis_off()
#     plt.gca().set_xticks([])
#     plt.gca().set_yticks([])
    
#     for elem in traces:
        
# #         print(elem)
# #         print('-------------------------')
# #         print(elem['label'])
#         ls = elem['trace_group']
#         output_path = output_path  
        
#         for subls in ls:
# #             print(subls)
            
#             data = np.array(subls)
#             # raise Exception(data)
#             if data.shape[1] > 2:
#                 data = data[:, :2]
#             x,y=zip(*data)
#             plt.plot(x,y,linewidth=2,c='black')
#     try:
#         os.mkdir(output_path)
#     except OSError:
# #             print ("Folder %s Already Exists" % ind_output_path)
# #             print(OSError.strerror)
#         pass
#     else:
# #             print ("Successfully created the directory %s " % ind_output_path)
#         pass
# #         print(ind_output_path+'/'+path+str(file_name)+'.png')
    input_path_safe = input_path.replace('/', '_') + '_'
    if(os.path.isfile(output_path+'/'+input_path_safe+str(file_name)+'.png')):
        # print('1111')
        file_name += 1
        plt.savefig(output_path+'/'+input_path_safe+str(file_name)+'.png', bbox_inches='tight', dpi=100)
    else:
        plt.savefig(output_path+'/'+input_path_safe+str(file_name)+'.png', bbox_inches='tight', dpi=100)
    plt.gcf().clear()

def ink2img_folder(input_paths, output_path):
    labels = collections.defaultdict(list)
    for input_path in input_paths:
        input_path_safe = input_path.replace('/', '_') + '_'
        files = os.listdir(input_path)
        # ignore all files that don't have the .inkML extension
        files = [file for file in files if file.endswith('.inkml')]
        for file in tqdm(files):
        #     print(file)
            if output_path[-1] != "/": output_path += "/"
            inkML_path = os.path.join(input_path, file)
            try: 
                labels["label"].append(get_gt(inkML_path))
                labels["name"].append((input_path_safe+file).replace('.','_').replace('_inkml', '.inkml')+'_0.png')
                inkml2img(inkML_path, output_path)
            except: print("Error with file: " + str(file) + " in folder: " + str(input_path) + ". Don't worry, this is expected (though there should only be max 2 or 3!).")
    pd.DataFrame(labels).to_csv(output_path + "labels.csv", index=False)

In [27]:
from PIL import Image
import re
import numpy as np
from IPython.display import display, Math, Latex

def crop_to_formula(image, padding = 30):
    # Image: 4 channel image with alpha.
    # Convert black pixels to white pixels.
    data = np.array(image)
    red, green, blue, alpha = data.T
    black_areas = (red < 10) & (blue < 10) & (green < 10)
    # Convert alpha to white.
    data[..., -1] = 255
    # Crop a box around the area that contains black pixels.
    coords = np.argwhere(black_areas)
    x0, y0 = coords.min(axis=0)
    x1, y1 = coords.max(axis=0) + 1
    # Add padding.
    x0 = max(0, x0 - padding)
    y0 = max(0, y0 - padding)
    x1 = min(image.width, x1 + padding)
    y1 = min(image.height, y1 + padding)
    image = Image.fromarray(data[y0:y1, x0:x1])
    return image.convert('RGB')

def renderedLaTeXLabelstr2Formula(label: str):
    # We're matching \\label{...whatever} and removing it
    label = re.sub(r"\\label\{[^\}]*\}", "", label)
    # We match \, and remove it.
    label = re.sub(r"\\,", "", label)
    return label

def display_formula(latex: str):
    # Remove \mbox{...} - not supported by the inline MathJax renderer
    parsed_latex = re.sub(r"\\mbox\{[^\}]*\}", "", latex)
    display(Math(parsed_latex))

In [29]:
import os, sys
sys.path.append(os.path.abspath('../'))


def test_renderedLaTeXDataset(dataset, processor):
    
    iter_ = iter(dataset)
    inputs, captions = next(iter_)
    inputs_2, captions_2 = next(iter_)
    assert ''.join(processor.batch_decode(captions, skip_special_tokens=True)) != ''.join(processor.batch_decode(captions_2, skip_special_tokens=True)), "Passed dataset yields repeat captions."

    print("renderedLaTeXDataset tests passed.")

In [30]:
set_seed(0)
optimizer = t.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
train_transforms = transforms.Compose([
    transforms.v2.RandomAffine(degrees = 5,
                               scale = (0.7, 1.1),
                               shear = 30),
    transforms.v2.ColorJitter(brightness = 0.2,
                              contrast = 0.2,
                              saturation = 0.2,
                              hue = 0.1)
])

train_ds = renderedLaTeXDataset(image_folder = r"C:\Users\prana\Downloads\56198\formula_images\formula_images", 
                                lst_file = r"C:\Users\prana\Downloads\56198\im2latex_train.lst", 
                                formulas_file = r"C:\Users\prana\Downloads\56198\im2latex_formulas.lst", 

                                device = device,
                                processor = processor,
                            )
val_ds = renderedLaTeXDataset(image_folder = r"C:\Users\prana\Downloads\56198\formula_images\formula_images",
                                lst_file = r"C:\Users\prana\Downloads\56198\im2latex_validate.lst",
                                formulas_file = r"C:\Users\prana\Downloads\56198\im2latex_formulas.lst",
                                device = device,
                                processor = processor)
train_dl = DataLoader(train_ds, batch_size = BATCH_SIZE, shuffle = SHUFFLE_DATASET, num_workers = 0)
val_dl = DataLoader(val_ds, batch_size = BATCH_SIZE, shuffle = False, num_workers = 0)

test_renderedLaTeXDataset(train_ds, processor = processor)
test_renderedLaTeXDataset(val_ds, processor = processor)

print("Number of training samples:", len(train_ds))
print("Number of validation samples:", len(val_ds))

  return self.preprocess(images, **kwargs)


renderedLaTeXDataset tests passed.
renderedLaTeXDataset tests passed.
Number of training samples: 83884
Number of validation samples: 9320


In [31]:
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.to(device)
model.train()

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-23): 24 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=False)
              (key): Linear(in_features=1024, out_features=1024, bias=False)
              (value): Linear(in_features=1024, out_features=1024, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Line

In [32]:
history = []; val_history = []; val_timesteps = []
ema_loss = None; ema_alpha = 0.95
scaler = t.cuda.amp.GradScaler(enabled = True)
for epoch in range(NUM_EPOCHS):
    with tqdm(train_dl, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}") as pbar:
        for batch, captions in pbar:
            pixel_values = batch["pixel_values"]
            
            optimizer.zero_grad()
            with t.autocast(device_type = "cuda", dtype = t.float16, enabled = True):
                outputs = model(pixel_values = pixel_values,
                                labels = captions)
                loss = outputs.loss
                history.append(loss.item())
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if ema_loss is None: ema_loss = loss.item()
            else: ema_loss = ema_loss * ema_alpha + loss.item() * (1 - ema_alpha)
            pbar.set_postfix(loss=ema_loss)
    
    model.eval()
    with t.no_grad():
        val_losses = []
        for batch, captions in tqdm(val_dl):
            pixel_values = batch["pixel_values"]
            outputs = model(pixel_values = pixel_values,
                            labels = captions)
            val_losses.append(outputs.loss.item())
        print(f"Validation loss: {np.mean(val_losses)}")
        val_history.append(np.mean(val_losses))
        val_timesteps.append(len(history) - 1)

Epoch 1/2:   0%|                                                                             | 0/20971 [00:14<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 38.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 10.40 GiB is allocated by PyTorch, and 237.83 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Save model
model.save_pretrained(r"C:/Users/prana/Downloads/Models/trocr-large-rendered-im2latex")
processor.save_pretrained(r"C:/Users/prana/Downloads/Models/trocr-large-rendered-im2latex")
t.save(history, r"C:/Users/prana/Downloads/Models/trocr-large-rendered-im2latex/history.pt")
t.save(val_history, r"C:/Users/prana/Downloads/Models/trocr-large-rendered-im2latex/val_history.pt")
t.save(val_timesteps, r"C:/Users/prana/Downloads/Models/val_timesteps.pt")