In [11]:
import os
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from tqdm import tqdm

def inkml_to_image(inkml_file, img_size=128, line_width=3):
    tree = ET.parse(inkml_file)
    root = tree.getroot()

    # Extract traces (each stroke is a list of x,y points)
    ns = {'ink': 'http://www.w3.org/2003/InkML'}
    traces = []
    for trace in root.findall('ink:trace', ns):
        coords = trace.text.strip().split(',')
        stroke = []
        for c in coords:
            points = c.strip().split(' ')
            if len(points) >= 2:
                x, y = map(float, points[:2])
                stroke.append((x, -y))  # flip y
        traces.append(stroke)

    # Render strokes into an image
    fig, ax = plt.subplots(figsize=(1,1), dpi=img_size)
    ax.axis("off")
    for stroke in traces:
        if len(stroke) > 1:
            xs, ys = zip(*stroke)
            ax.plot(xs, ys, linewidth=line_width, color="black")
    fig.canvas.draw()

    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)
    return Image.fromarray(img).convert("L")


In [12]:
def get_latex_label(inkml_file):
    tree = ET.parse(inkml_file)
    root = tree.getroot()
    for ann in root.findall('annotation'):
        if ann.attrib.get('type') == 'normalizedLabel':
            return ann.text
    return None


In [13]:
import csv

def preprocess_split(input_dir, output_dir, split_name, limit=None):
    os.makedirs(output_dir, exist_ok=True)
    csv_path = os.path.join(output_dir, f"{split_name}_labels.csv")

    inkml_files = [f for f in os.listdir(input_dir) if f.endswith(".inkml")]
    if limit:
        inkml_files = inkml_files[:limit]

    with open(csv_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["image", "label"])

        for i, fname in enumerate(tqdm(inkml_files, desc=f"Processing {split_name}")):
            inkml_path = os.path.join(input_dir, fname)
            label = get_latex_label(inkml_path)
            if label is None:
                continue

            img = inkml_to_image(inkml_path, img_size=128)
            img_name = f"{split_name}_{i}.png"
            img_path = os.path.join(output_dir, img_name)
            img.save(img_path)

            writer.writerow([img_name, label])

    print(f"✅ {split_name} done → {len(inkml_files)} samples saved in {output_dir}")


In [14]:
DATASET_ROOT = "mathwriting-2024"
OUTPUT_ROOT = "processed_mathwriting"

# Preprocess subsets (limit=500 for quick test, remove later)
preprocess_split(os.path.join(DATASET_ROOT, "train"), OUTPUT_ROOT, "train", limit=500)
preprocess_split(os.path.join(DATASET_ROOT, "valid"), OUTPUT_ROOT, "valid", limit=100)
preprocess_split(os.path.join(DATASET_ROOT, "test"), OUTPUT_ROOT, "test", limit=100)


Processing train: 100%|██████████| 500/500 [00:12<00:00, 40.74it/s]


✅ train done → 500 samples saved in processed_mathwriting


Processing valid: 100%|██████████| 100/100 [00:02<00:00, 42.50it/s]


✅ valid done → 100 samples saved in processed_mathwriting


Processing test: 100%|██████████| 100/100 [00:02<00:00, 41.46it/s]

✅ test done → 100 samples saved in processed_mathwriting





In [2]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Embedding, TimeDistributed, Masking
from tensorflow.keras.optimizers import Adam
import numpy as np

# -------------------------------
# 1. Prepare Input Data
# -------------------------------
# Convert traces into sequences of (dx, dy, pen_state)
def prepare_stroke_sequences(traces, max_len=200):
    sequences = []
    for sample in traces:
        seq = []
        for trace in sample:
            for i in range(1, len(trace)):
                dx = trace[i][0] - trace[i-1][0]
                dy = trace[i][1] - trace[i-1][1]
                pen_state = 1 if i == len(trace)-1 else 0  # pen-up at end of stroke
                seq.append([dx, dy, pen_state])
        seq = np.array(seq)
        if len(seq) > max_len:
            seq = seq[:max_len]
        else:
            pad = np.zeros((max_len - len(seq), 3))
            seq = np.vstack([seq, pad])
        sequences.append(seq)
    return np.array(sequences, dtype=np.float32)

# Example conversion
X_train = prepare_stroke_sequences(train_data, max_len=200)
X_valid = prepare_stroke_sequences(valid_data, max_len=200)

y_train = train_sequences   # already tokenized and padded
y_valid = valid_sequences

print("Input shape:", X_train.shape)
print("Output shape:", y_train.shape)

# -------------------------------
# 2. Build Seq2Seq Model
# -------------------------------
max_seq_len = y_train.shape[1]   # max length of LaTeX tokens
vocab_size = len(tokenizer.word_index) + 1  # +1 for padding

inputs = Input(shape=(200, 3))   # 200 timesteps, (dx, dy, pen_state)
x = Masking(mask_value=0.0)(inputs)
x = LSTM(256, return_sequences=True)(x)
x = LSTM(256)(x)

# Output: LaTeX tokens
outputs = Dense(vocab_size, activation="softmax")(x)

model = Model(inputs, outputs)
model.compile(optimizer=Adam(1e-3), loss="sparse_categorical_crossentropy", metrics=["accuracy"])

model.summary()

# -------------------------------
# 3. Training
# -------------------------------
# Since we're predicting the whole sequence, we shift y to categorical
y_train_cls = y_train[:, 0]   # predict only first token (baseline simplification)
y_valid_cls = y_valid[:, 0]

history = model.fit(
    X_train, y_train_cls,
    validation_data=(X_valid, y_valid_cls),
    batch_size=64,
    epochs=5
)


Input shape: (229864, 200, 3)
Output shape: (229864, 10)


Epoch 1/5
[1m3592/3592[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3449s[0m 959ms/step - accuracy: 0.0168 - loss: 10.0684 - val_accuracy: 0.0226 - val_loss: 10.6059
Epoch 2/5
[1m3592/3592[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3253s[0m 906ms/step - accuracy: 0.0180 - loss: 9.6168 - val_accuracy: 0.0230 - val_loss: 11.0873
Epoch 3/5
[1m3592/3592[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3388s[0m 943ms/step - accuracy: 0.0221 - loss: 9.2702 - val_accuracy: 0.0249 - val_loss: 11.2937
Epoch 4/5
[1m3592/3592[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3251s[0m 905ms/step - accuracy: 0.0268 - loss: 8.6483 - val_accuracy: 0.0308 - val_loss: 11.4968
Epoch 5/5
[1m3592/3592[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3815s[0m 1s/step - accuracy: 0.0340 - loss: 7.9656 - val_accuracy: 0.0392 - val_loss: 11.4690
