In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import ViTModel, BertTokenizer, BertConfig
import os
import xml.etree.ElementTree as ET
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from collections import defaultdict


In [None]:
train_img_base_dir = 'dataset/crohme/train/extracted_img'
# os.path.exists(train_img_base_dir)
os.listdir(train_img_base_dir)

## 1. 数据集准备

In [None]:
def create_combined_csv(inkml_folder, latex_folder, img_folder, output_csv):
    data = []
    
    # 遍历 INKML 文件夹，建立文件路径的索引
    inkml_file_paths = {}
    
    for root, _, files in os.walk(inkml_folder):
        for file in files:
            if file.endswith('.inkml'):
                inkml_file_paths[file] = os.path.join(root, file)
    
    # 遍历 LaTeX 表达式文件夹
    for root, _, files in os.walk(latex_folder):
        for file in files:
            if file.endswith('.lg'):  # 假设 LaTeX 表达式文件的后缀为 .lg
                latex_file_path = os.path.join(root, file)
                
                with open(latex_file_path, 'r', encoding='utf-8') as f:
                    latex_lines = f.readlines()
                    
                    for inkml_file_name, latex in zip(inkml_file_paths.keys(), latex_lines):
                        latex = latex.strip()  # 去除空白符
                        
                        # 找到已经转换的 PNG 图像文件
                        image_file_name = os.path.basename(inkml_file_name).replace('.inkml', '.png')
                        image_file_path = os.path.join(img_folder, image_file_name)
                        
                        if os.path.exists(image_file_path):
                            data.append([image_file_path, latex])
                        else:
                            print(f"Warning: Image file not found for {image_file_path}")

    # 将结果保存为 CSV 文件
    if data:
        df = pd.DataFrame(data, columns=['image_path', 'latex'])
        df.to_csv(output_csv, index=False)
        print(f"CSV 文件已创建，路径为: {output_csv}")
    else:
        print("No matching data found. Please check your file paths.")

# 使用文件夹路径来调用函数
inkml_folder = 'TC11_CROHME23/INKML/train/CROHME2019'  # InkML 文件夹路径
latex_folder = 'TC11_CROHME23/SymLG/train/CROHME2019_train'  # LaTeX 表达式文件夹路径
img_folder = 'TC11_CROHME23/IMG/train/CROHME2019'  # 已转换的 PNG 图像文件存储路径
output_csv = 'crohme_labels.csv'  # 输出 CSV 文件路径

create_combined_csv(inkml_folder, latex_folder, img_folder, output_csv)

## 2.  数据加载

In [24]:
" 定义数据集与数据加载器 "
class CROHMEDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.iloc[idx, 0]
        latex_expr = self.data.iloc[idx, 1]
        
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)
        
        return image, latex_expr

In [25]:
class ViTEncoder(nn.Module):
    def __init__(self, pretrained_model_name='google/vit-base-patch16-224-in21k'):
        super(ViTEncoder, self).__init__()
        self.encoder = ViTModel.from_pretrained(pretrained_model_name)
        
    def forward(self, x):
        outputs = self.encoder(pixel_values=x)
        return outputs.last_hidden_state

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.positional_encoding = nn.Parameter(torch.zeros(1, 500, hidden_dim))
        self.decoder_layer = nn.TransformerDecoderLayer(hidden_dim, num_heads)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, encoder_outputs, tgt, tgt_mask):
        tgt_embedded = self.embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :]
        outputs = self.transformer_decoder(tgt_embedded, encoder_outputs, tgt_mask=tgt_mask)
        outputs = self.fc_out(outputs)
        return outputs

# 组合模型
class ImageToLatexModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim=512, num_layers=6, num_heads=8):
        super(ImageToLatexModel, self).__init__()
        self.encoder = ViTEncoder()
        self.decoder = TransformerDecoder(vocab_size, hidden_dim, num_layers, num_heads)

    def forward(self, x, tgt, tgt_mask):
        encoder_outputs = self.encoder(x)
        outputs = self.decoder(encoder_outputs, tgt, tgt_mask)
        return outputs

In [26]:

def train_one_epoch(model, train_loader, optimizer, criterion):
    model.train()
    epoch_loss = 0
    for images, latex_exprs in train_loader:
        images = images.to(device)
        latex_exprs = latex_exprs.to(device)  # 假设已编码为张量
        optimizer.zero_grad()

        tgt_input = latex_exprs[:, :-1]
        tgt_output = latex_exprs[:, 1:]

        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
        output = model(images, tgt_input, tgt_mask)

        loss = criterion(output.view(-1, output.size(-1)), tgt_output.view(-1))
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    return epoch_loss / len(train_loader)

In [None]:
if __name__ == "__main__":

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    train_dataset = CROHMEDataset('crohme_labels.csv', transform=transform)
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    # 假设我们已经准备好词汇表并对目标文本进行了编码
    vocab_size = 10000  # 假设词汇表大小
    model = ImageToLatexModel(vocab_size).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss(ignore_index=0)


    # 开始训练
    num_epochs = 10
    for epoch in range(num_epochs):
        epoch_loss = train_one_epoch(model, train_dataloader, optimizer, criterion)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}')


In [None]:
def generate_latex(model, image, tokenizer, max_length=100):
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)
        encoder_outputs = model.encoder(image)

        tgt_input = torch.tensor([[tokenizer.cls_token_id]], device=device)

        for _ in range(max_length):
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            output = model.decoder(encoder_outputs, tgt_input, tgt_mask)
            
            next_token = output.argmax(2)[:, -1].unsqueeze(0)
            tgt_input = torch.cat([tgt_input, next_token], dim=1)

            if next_token.item() == tokenizer.sep_token_id:
                break

    return tokenizer.decode(tgt_input.squeeze().tolist(), skip_special_tokens=True)

# 使用 BertTokenizer 进行示例推理
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
test_image, _ = train_dataset[0]  # 获取测试图像

latex_expr = generate_latex(model, test_image, tokenizer)
print(latex_expr)

In [None]:
def save_to_ig_format(latex_expr, output_path):
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(latex_expr)

# 示例保存
output_path = 'output/formula.ig'
save_to_ig_format(latex_expr, output_path)

# Older Version

In [None]:
" 构建 LaTeX 表达式词汇表与编码函数 "
# 构建词汇表
vocab = defaultdict(int)

# 确保遍历 LaTeX 表达式所在的列
for _, row in dataset.data.iterrows():
    latex = row['latex']  # 获取 'latex' 列
    for token in latex.split():
        vocab[token] += 1

# 分配索引：<pad>, <bos>, <eos> 分别为填充、起始和结束标记
vocab = {token: idx + 3 for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: -x[1]))}
vocab['<pad>'] = 0
vocab['<bos>'] = 1
vocab['<eos>'] = 2

# 构建反向词汇表
idx2vocab = {idx: token for token, idx in vocab.items()}

# 定义编码与解码函数
def encode_latex(latex):
    return [vocab['<bos>']] + [vocab.get(token, vocab['<pad>']) for token in latex.split()] + [vocab['<eos>']]

def decode_latex(indices):
    return ' '.join(idx2vocab[idx] for idx in indices if idx not in [0, 1, 2])


## 3. 模型架构

In [None]:
" 定义 ViT + Transformer 模型 "
import torch
import torch.nn as nn
from transformers import ViTModel

class ViTTransformerModel(nn.Module):
    def __init__(self, vit_model_name='google/vit-base-patch16-224', decoder_vocab_size=3000, d_model=768, num_decoder_layers=6):
        super(ViTTransformerModel, self).__init__()
        
        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.d_model = d_model
        
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=8)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        
        self.embedding = nn.Embedding(decoder_vocab_size, d_model)
        self.fc_out = nn.Linear(d_model, decoder_vocab_size)
        self.positional_encoding = nn.Parameter(torch.zeros(1, 500, d_model))

    def forward(self, images, target_seq, target_mask):
        encoded_features = self.vit(images).last_hidden_state
        memory = encoded_features.permute(1, 0, 2)
        
        tgt_embedded = self.embedding(target_seq) + self.positional_encoding[:, :target_seq.size(1), :]
        tgt_embedded = tgt_embedded.permute(1, 0, 2)
        
        output = self.transformer_decoder(tgt_embedded, memory, tgt_mask=target_mask)
        output = self.fc_out(output)
        
        return output.permute(1, 0, 2)

In [None]:
" 训练模型 "
model = ViTTransformerModel(decoder_vocab_size=len(vocab))
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for images, target_seq in dataloader:
        optimizer.zero_grad()
        
        target_input = torch.tensor([([vocab['<bos>']] + encode_latex(seq)[:-1]) for seq in target_seq], dtype=torch.long)
        target_output = torch.tensor([encode_latex(seq) + [vocab['<eos>']] for seq in target_seq], dtype=torch.long)
        
        target_mask = generate_square_subsequent_mask(target_input.size(1)).to(images.device)
        outputs = model(images, target_input, target_mask)
        
        loss = criterion(outputs.view(-1, outputs.size(-1)), target_output.view(-1))
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(dataloader)}")


## 推理

In [None]:
def predict_latex(model, image, start_token, end_token, max_length=500):
    model.eval()
    target_seq = torch.LongTensor([[start_token]]).to(image.device)
    
    with torch.no_grad():
        for _ in range(max_length):
            target_mask = generate_square_subsequent_mask(target_seq.size(1)).to(image.device)
            output = model(image, target_seq, target_mask)
            
            next_token = output.argmax(-1)[:, -1]
            target_seq = torch.cat([target_seq, next_token.unsqueeze(1)], dim=1)
            
            if next_token.item() == end_token:
                break
                
    return decode_latex(target_seq.squeeze().tolist())

# Example usage for one image
image, _ = dataset[0]
image = image.unsqueeze(0)  # Add batch dimension
predicted_latex = predict_latex(model, image, vocab['<bos>'], vocab['<eos>'])
print(predicted_latex)

In [10]:
import pickle
import os
import cv2

def extract_img(pkl_file: str, save_dir: str):
    """
    Extracts images from a pickle file and saves them as image files in the specified directory.
    
    Parameters:
        pkl_file (str): Path to the pickle file containing image data.
        save_dir (str): Directory where the extracted images will be saved.
    """
    # Load the pickle file
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)
    
    # Ensure the output directory exists
    os.makedirs(save_dir, exist_ok=True)
    
    # Iterate through the dictionary items
    for image_name, image_array in data.items():
        # Ensure image_name has a proper extension
        if not image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_name += '.png'  # Default to .png if no extension
        
        # Construct the full path for saving the image
        output_path = os.path.join(save_dir, image_name)
        
        # Convert image_array to an appropriate data type if needed
        if image_array.dtype != 'uint8':
            image_array = image_array.astype('uint8')
        
        # Save the image using OpenCV
        success = cv2.imwrite(output_path, image_array)
        if success:
            print(f"Saved image: {output_path}")
        else:
            print(f"Failed to save image: {output_path}")

In [16]:

" Merge image and labels to a csv "
def merge_hmer_or_crohme(images: str, caption: str, extract_img_dir: str, save_file: str):
    """
    Optimized function to read images from a pickle file and labels from a text file,
    then saves the matched image paths with their corresponding LaTeX labels into a CSV file.

    Parameters:
        images (str): Path to the pickle (.pkl) file containing image paths.
        caption (str): Path to the caption text file containing image-label pairs.
        save_file (str): Path to the output CSV file where the results will be saved.
    """
    
    # Load the image paths from the pickle file
    with open(images, 'rb') as f:
        image_list = pickle.load(f)
    """
    Load image paths into `image_list` from the pickle file.
    """

    # Create a dictionary for fast lookup: key is the image filename, value is the full path
    image_dict = {os.path.basename(img_path): img_path for img_path in image_list}
    """
    The `image_dict` allows for O(1) average-time complexity lookups for matching images.
    """

    # Read the caption file
    with open(caption, 'r', encoding='utf-8') as f:
        data = f.readlines()
    """
    Read all lines from the caption.txt file.
    """
    
    img_labels = []

    # Process each line in the caption file
    for line in data:
        """
        Loop through each line in the caption file.
        """
        
        # Strip any leading/trailing whitespace/newlines
        line = line.strip()
        
        # Check for tab separation and split accordingly
        if '\t' in line:
            image_name, label = line.split('\t', 1)  # Split by the first tab
        else:
            # Skip if improperly formatted
            continue
        
        # Use the dictionary to find the matching image path
        if image_name in image_dict:
            # image_path = (extract_img_dir + '/' + image_dict[image_name])  # if extract_img_dir not None else image_dict[image_name]
            image_path = os.path.join(extract_img_dir, image_dict[image_name] + '.png')
            print(image_path)
            if os.path.isfile(image_path):
                img_labels.append([image_path, label])
            else:
                print(f"Warning: Image {image_name} not found in directory {extract_img_dir}. Skipping.")
        else:
            print(f"Warning: Image {image_name} not found in the pickle file. Skipping.")
    """
    The loop now efficiently finds image paths using dictionary lookups instead of iterating through a list.
    """

    # Create a DataFrame from the list of image-label pairs
    df = pd.DataFrame(img_labels, columns=['Image Path', 'Label'])

    # Save the DataFrame to a CSV file without index
    df.to_csv(save_file, index=False)
    print(f"CSV file saved as {save_file}")

In [20]:
training_img_pkl_dir = 'dataset/crohme/train/images.pkl'
train_img_base_dir = 'dataset/crohme/train/extracted_img'
caption_dir = 'dataset/crohme/train/caption.txt'
mapping_csv = 'dataset/crohme/train/crohme_labels.csv'

# extract_img(pkl_file=training_img_pkl_dir, save_dir=train_img_base_dir)
# train_img_base_dir is empty
if not os.path.exists(train_img_base_dir) or os.path.exists(train_img_base_dir) and not os.listdir(train_img_base_dir):
    print('Extracting images...')

Extracting images...


In [17]:
merge_hmer_or_crohme(images=training_img_pkl_dir, caption=caption_dir, extract_img_dir=train_img_base_dir, save_file=mapping_csv)

dataset/crohme/train/extracted_img/200924-1331-216.png
dataset/crohme/train/extracted_img/200923-131-185.png
dataset/crohme/train/extracted_img/200923-1553-117.png
dataset/crohme/train/extracted_img/200923-1251-17.png
dataset/crohme/train/extracted_img/200923-1556-256.png
dataset/crohme/train/extracted_img/200923-1254-165.png
dataset/crohme/train/extracted_img/200923-1254-243.png
dataset/crohme/train/extracted_img/200923-1254-166.png
dataset/crohme/train/extracted_img/200923-1553-144.png
dataset/crohme/train/extracted_img/200923-1251-111.png
dataset/crohme/train/extracted_img/200926-131-59.png
dataset/crohme/train/extracted_img/2009212-952-76.png
dataset/crohme/train/extracted_img/200923-1251-203.png
dataset/crohme/train/extracted_img/200924-1331-58.png
dataset/crohme/train/extracted_img/200923-131-63.png
dataset/crohme/train/extracted_img/200923-1253-130.png
dataset/crohme/train/extracted_img/2009213-139-16.png
dataset/crohme/train/extracted_img/2009213-139-221.png
dataset/crohme/trai