In [1]:
import kagglehub

# Fetch the newest published version of the Kaggle dataset (kagglehub reuses its local cache
# when the files are already present) and keep the local directory in DATA_PATH.
# NOTE(review): "latest" is not pinned, so results may change if the dataset is updated.
DATASET_HANDLE = "theseus200719/math-equations-dataset-aidav7-modified"
DATA_PATH = kagglehub.dataset_download(DATASET_HANDLE)

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import cv2
import os
import ast
import numpy as np
from tqdm import tqdm
import json

In [8]:
def resize_and_pad(img, target_height=64, target_width=256):
    """Fit a 2-D grayscale image into a zero-filled (target_height, target_width) float32 canvas.

    The image is scaled to ``target_height`` rows, keeping its aspect ratio. If that would make
    it wider than ``target_width``, it is instead scaled to exactly ``target_width`` columns
    (still keeping the aspect ratio), so it ends up shorter than ``target_height``. The image
    is placed in the top-left corner; the rest of the canvas (right side, and for wide images
    also the bottom) stays 0.

    Args:
        img: 2-D numpy array, e.g. the result of ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``.
        target_height: output height in pixels.
        target_width: output width in pixels.

    Returns:
        float32 array of shape (target_height, target_width) with the original pixel values.

    Raises:
        ValueError: if ``img`` is None (e.g. the image file could not be read), not 2-D, or empty.
    """
    if img is None:
        raise ValueError("img is None (was the image file readable?)")
    if img.ndim != 2:
        raise ValueError(f"Expected a 2-D grayscale image, got shape {img.shape}")
    h, w = img.shape
    if h == 0 or w == 0:
        raise ValueError(f"Cannot resize an empty image of shape {img.shape}")

    scale = target_height / h
    new_h = target_height
    new_w = int(w * scale)
    if new_w > target_width:
        # Too wide: fit the width and shrink the height by the same factor. Clamping only the
        # width would squash long formulas horizontally and break the aspect ratio.
        scale = target_width / w
        new_w = target_width
        new_h = min(target_height, max(1, round(h * scale)))
    # Very narrow images can round down to 0 columns, which cv2.resize rejects.
    new_w = max(1, new_w)

    resized = cv2.resize(img, (new_w, new_h))  # cv2 expects (width, height)

    padded = np.zeros((target_height, target_width), dtype=np.float32)
    padded[:new_h, :new_w] = resized
    return padded

In [9]:
# ===== Dataset =====
class FormulaDataset(Dataset):
    """Pairs of (formula image, token ids) read from ``<path>/annotations.csv`` and ``<path>/images``.

    Each CSV row must provide a ``filenames`` column (image file name relative to
    ``<path>/images``) and an ``image_data`` column holding a Python-literal dict with a
    ``full_latex_chars`` list of token strings.

    Args:
        path: dataset root directory.
        vocab: object with an ``encode(list_of_tokens) -> list_of_ids`` method.
        transform: optional callable applied to the (1, H, W) float image tensor.
    """

    def __init__(self, path, vocab, transform=None):
        self.data = pd.read_csv(os.path.join(path, 'annotations.csv'))
        self.img_dir = os.path.join(path, 'images')
        self.vocab = vocab
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(img, token_ids)``: a float tensor (1, H, W) in [0, 1] and a 1-D LongTensor."""
        row = self.data.iloc[idx]

        # --- image ---
        img_path = os.path.join(self.img_dir, row['filenames'])
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = resize_and_pad(img)  # fixed-size canvas so images can be stacked into a batch
        img = img / 255.0  # scale 8-bit pixel values to [0, 1]
        img = torch.tensor(img).unsqueeze(0).float()  # (1, H, W)
        if self.transform is not None:
            img = self.transform(img)

        # --- tokens ---
        info = ast.literal_eval(row['image_data'])
        tokens = info['full_latex_chars']  # list of token strings
        token_ids = self.vocab.encode(tokens)

        return img, torch.tensor(token_ids, dtype=torch.long)


In [10]:
# ===== Vocabulary =====
class Vocab:
    """Bidirectional mapping between LaTeX token strings and integer ids.

    Ids 0-3 are reserved for ``<pad>``, ``<bos>``, ``<eos>`` and ``<unk>`` (in that order).
    The remaining ids go to the distinct regular tokens in sorted order.
    """

    def __init__(self, token_list):
        # Special tokens
        self.pad = "<pad>"
        self.bos = "<bos>"
        self.eos = "<eos>"
        self.unk = "<unk>"
        specials = [self.pad, self.bos, self.eos, self.unk]

        # Deduplicate and drop special tokens from the input. Otherwise a repeated token would
        # occupy two ids while stoi silently keeps only the last one (and len() would be off).
        regular = sorted(set(token_list) - set(specials))
        self.tokens = specials + regular
        self.stoi = {t: i for i, t in enumerate(self.tokens)}
        self.itos = {i: t for t, i in self.stoi.items()}

        self.length = len(self.tokens)  # vocabulary size, including special tokens

    def __len__(self):
        return self.length

    def encode(self, token_seq):
        """Map a list of token strings to ``[<bos>, ids..., <eos>]``; unknown tokens become ``<unk>``.

        Raises:
            ValueError: if ``token_seq`` is not a list or contains non-string items.
        """
        if not isinstance(token_seq, list):
            raise ValueError(f"Expected list of tokens, got {type(token_seq)}")

        if not all(isinstance(t, str) for t in token_seq):
            raise ValueError("All tokens must be strings")

        return [self.stoi[self.bos]] + \
               [self.stoi.get(t, self.stoi[self.unk]) for t in token_seq] + \
               [self.stoi[self.eos]]

    def decode(self, ids):
        """Convert ids back to a LaTeX string, stopping at the first ``<eos>``.

        ``<pad>`` and ``<bos>`` are skipped; ids outside the vocabulary decode to ``<unk>``.
        Accepts plain ints or 0-d integer tensors.
        """
        eos_id = self.stoi[self.eos]
        skip_ids = {self.stoi[self.pad], self.stoi[self.bos]}

        toks = []
        for raw in ids:
            i = int(raw)  # tensors hash by identity, so normalize before dict lookups
            if i == eos_id:
                break
            if i in skip_ids:
                continue
            toks.append(self.itos.get(i, self.unk))

        return self._join_latex(toks)

    @staticmethod
    def _join_latex(toks):
        """Concatenate tokens, inserting a space only where a control word such as ``\\sin``
        would otherwise merge with a following letter (``\\sinx`` is not valid LaTeX)."""
        def is_letter(ch):
            return ch.isascii() and ch.isalpha()

        parts = []
        prev = ""
        for tok in toks:
            if prev.startswith("\\") and is_letter(prev[-1:]) and is_letter(tok[:1]):
                parts.append(" ")
            parts.append(tok)
            prev = tok
        return "".join(parts)


In [11]:
# Collect every distinct LaTeX token that occurs in any annotation; this set defines the vocabulary.
df = pd.read_csv(os.path.join(DATA_PATH, "annotations.csv"))

# One token list per annotation row (image_data is a Python-literal dict per row).
all_token_lists = [
    ast.literal_eval(image_data)['full_latex_chars']
    for image_data in tqdm(df['image_data'], total=len(df))
]

unique_tokens = set().union(*all_token_lists)

100%|██████████| 100000/100000 [01:06<00:00, 1502.19it/s]


In [12]:
# Persist the token inventory. Sorting makes the file byte-identical across runs
# (set iteration order depends on per-process string hash randomization).
with open("tokens.json", "w", encoding="utf-8") as f:
    json.dump(sorted(unique_tokens), f)

In [13]:
# Rebuild the vocabulary from the saved token file so the id mapping survives kernel restarts.
with open("tokens.json", "r") as token_file:
    token_list = json.load(token_file)
vocab = Vocab(token_list)

print(vocab.tokens)

['<pad>', '<bos>', '<eos>', '<unk>', '+', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', '[', '\\cdot', '\\cos', '\\cot', '\\csc', '\\frac', '\\infty', '\\left(', '\\left|', '\\lim_', '\\ln', '\\log', '\\pi', '\\right)', '\\right|', '\\sec', '\\sin', '\\sqrt', '\\tan', '\\theta', '\\to', ']', '^', '_', 'a', 'b', 'c', 'd', 'e', 'g', 'h', 'k', 'n', 'p', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '}']


In [14]:
def collate_fn(batch, pad_id=None):
    """Batch ``(img, token_ids)`` pairs for a DataLoader.

    Args:
        batch: sequence of ``(img, token_ids)`` pairs; all images must share one shape,
            and token_ids are 1-D integer tensors of varying length.
        pad_id: id used to right-pad shorter sequences. Defaults to the ``<pad>`` id of the
            notebook-level ``vocab``, so ``DataLoader(..., collate_fn=collate_fn)`` keeps working.

    Returns:
        ``(imgs, padded_seqs)``: images stacked along a new batch dim, and a
        (batch, max_len) LongTensor padded with ``pad_id``.
    """
    imgs, seqs = zip(*batch)

    # Images were already resized to a common shape, so they can be stacked directly.
    imgs = torch.stack(imgs, dim=0)

    if pad_id is None:
        pad_id = vocab.stoi[vocab.pad]
    padded_seqs = torch.nn.utils.rnn.pad_sequence(
        [seq.long() for seq in seqs], batch_first=True, padding_value=pad_id
    )

    return imgs, padded_seqs

In [15]:
# Build the dataset and split it 80/20 into train/validation.
dataset = FormulaDataset(DATA_PATH, vocab)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# A seeded generator keeps the same split across kernel restarts. Without it, a checkpoint
# trained in one session would be "validated" on samples it was trained on in the next.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, pin_memory=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, pin_memory=True, collate_fn=collate_fn)

In [16]:
# --- CNN encoder ---
class CNNEncoder(nn.Module):
    """Convolutional encoder turning a 1-channel image into a sequence of feature vectors.

    Two 2x2 max-pools shrink each spatial side by 4, so a (B, 1, H, W) input yields a
    (B, (H // 4) * (W // 4), hidden_dim) sequence in row-major order (rows top to bottom,
    columns left to right).

    Args:
        hidden_dim: channels of the last conv layer = feature size of each sequence element.
        add_pos_encoding: add a fixed 2-D sinusoidal positional encoding to the flattened
            features. Without it, the decoder's cross-attention cannot tell which part of the
            image a feature came from (attention is order-agnostic over its memory). The
            encoding has no parameters, so state_dicts are unchanged.
    """

    def __init__(self, hidden_dim=256, add_pos_encoding=True):
        super().__init__()
        self.add_pos_encoding = add_pos_encoding
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, hidden_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_dim), nn.ReLU(),
        )

    @staticmethod
    def _pos_encoding_2d(h, w, dim, device, dtype):
        """Fixed sinusoidal encoding of shape (h * w, dim), row-major.

        The first ``dim // 2`` channels encode the row index and the rest the column index.
        Within each half, channels 2k / 2k+1 are sin / cos of ``pos / 10000^(2k / half)``.
        """
        def encode_1d(length, channels):
            pos = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)  # (length, 1)
            idx = torch.arange(channels, device=device)
            pair = torch.div(idx, 2, rounding_mode="floor").float()
            freq = 1.0 / (10000.0 ** (2 * pair / max(channels, 1)))                    # (channels,)
            angles = pos * freq                                                        # (length, channels)
            return torch.where(idx % 2 == 0, torch.sin(angles), torch.cos(angles))

        ch_y = dim // 2
        ch_x = dim - ch_y
        pe_y = encode_1d(h, ch_y).unsqueeze(1).expand(h, w, ch_y)
        pe_x = encode_1d(w, ch_x).unsqueeze(0).expand(h, w, ch_x)
        return torch.cat([pe_y, pe_x], dim=-1).reshape(h * w, dim).to(dtype)

    def forward(self, x):
        feats = self.conv(x)  # [B, hidden_dim, h, w]
        B, C, H, W = feats.shape
        # Flatten the spatial grid into a sequence: [B, H*W, C]
        feats = feats.permute(0, 2, 3, 1).reshape(B, H * W, C)
        if self.add_pos_encoding:
            feats = feats + self._pos_encoding_2d(H, W, C, feats.device, feats.dtype)
        return feats  # [B, seq_len, hidden_dim]


In [None]:
class TransformerDecoder(nn.Module):
    """Causal Transformer decoder over token ids, cross-attending to encoder features.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        hidden_dim: model width; must match the feature size of ``memory``.
        num_layers, nhead, dropout: forwarded to ``nn.TransformerDecoderLayer``/``nn.TransformerDecoder``.
        max_len: size of the learned positional embedding = longest supported target length.
    """

    def __init__(self, vocab_size, hidden_dim=256, num_layers=4, nhead=8, dropout=0.1, max_len=512):
        super().__init__()
        self.max_len = max_len
        self.token_emb = nn.Embedding(vocab_size, hidden_dim)
        self.pos_emb = nn.Embedding(max_len, hidden_dim)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, tgt, memory):
        """Return logits (B, T, vocab_size) for target ids ``tgt`` (B, T) and ``memory`` (B, S, hidden_dim).

        Raises:
            ValueError: if T exceeds ``max_len`` (the positional embedding has no row for it;
                otherwise this surfaces as an opaque index error / CUDA device assert).
        """
        B, T = tgt.shape
        if T > self.max_len:
            raise ValueError(f"Target length {T} exceeds max_len={self.max_len}")

        positions = torch.arange(T, device=tgt.device)
        # [B, T, H] + [1, T, H]: the positional embedding is broadcast over the batch.
        tgt_emb = self.token_emb(tgt) + self.pos_emb(positions).unsqueeze(0)

        # Causal mask [T, T]: True above the diagonal = position may not attend to the future.
        tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=tgt.device), diagonal=1)

        out = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask)  # [B, T, H]
        return self.fc_out(out)

In [36]:
# --- Full model ---
class FormulaRecognizer(nn.Module):
    """Image-to-LaTeX model: a CNNEncoder whose features feed a TransformerDecoder.

    Args:
        vocab_size: number of output token ids.
        hidden_dim: width shared by the encoder features and the decoder.
        max_len: decoder positional-embedding size and step limit of ``greedy_decode``.
    """

    def __init__(self, vocab_size, hidden_dim=256, max_len=512):
        super().__init__()
        self.encoder = CNNEncoder(hidden_dim)
        self.decoder = TransformerDecoder(vocab_size, hidden_dim, max_len=max_len)
        self.max_len = max_len
        self.vocab_size = vocab_size

    def forward(self, images, tokens):
        """Teacher-forced logits.

        ``tokens`` is the FULL (B, T) id sequence. The last position is dropped here, so the
        returned (B, T-1, vocab_size) logits line up with ``tokens[:, 1:]``. Do not slice
        ``tokens`` before calling, or the final target is lost.
        """
        memory = self.encoder(images)  # [B, S, H]
        return self.decoder(tokens[:, :-1], memory)  # [B, T-1, V]

    def greedy_decode(self, image, start_token, end_token, device=None):
        """Greedily decode one image, one token at a time.

        Args:
            image: a single image without batch dimension (one is added here).
            start_token: id that seeds decoding.
            end_token: id that stops decoding once produced.
            device: where to run; defaults to the device of the model's parameters.

        Returns:
            List of ids starting with ``start_token``; it ends with ``end_token`` if that was
            produced within ``max_len`` steps.

        The module's previous train/eval mode is restored afterwards.
        """
        if device is None:
            device = next(self.parameters()).device
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                memory = self.encoder(image.unsqueeze(0).to(device))  # [1, S, H]

                tokens = torch.tensor([[start_token]], device=device)
                for _ in range(self.max_len):
                    logits = self.decoder(tokens, memory)  # [1, T, V]
                    next_token = logits[:, -1, :].argmax(-1)  # [1]
                    tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=1)
                    if next_token.item() == end_token:
                        break
        finally:
            self.train(was_training)

        return tokens.squeeze(0).tolist()

In [37]:
def _sequence_loss(model, criterion, imgs, tokens):
    """Teacher-forced cross-entropy for one batch; returns ``(loss, logits)``.

    The FULL token tensor is passed to the model because ``FormulaRecognizer.forward`` drops
    the last position itself. Logit position i is scored against ``tokens[:, i + 1]``.
    """
    logits = model(imgs, tokens)  # [B, T-1, V]
    targets = tokens[:, 1:]
    # Defensive crop: a model whose forward does NOT drop the last token returns T positions,
    # and position i still predicts token i + 1, so cropping keeps the alignment.
    n = min(logits.size(1), targets.size(1))
    logits, targets = logits[:, :n, :], targets[:, :n]
    loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return loss, logits


def train_model(model, train_loader, val_loader, tokenizer, epochs=5, device="cuda"):
    """Train ``model`` with teacher forcing, validate, print examples and save after every epoch.

    Args:
        model: module called as ``model(imgs, tokens)`` returning per-position logits.
        train_loader, val_loader: yield ``(imgs, tokens)`` batches padded with the tokenizer's pad id.
        tokenizer: object with ``stoi``, ``pad`` and ``decode(ids)`` (e.g. ``Vocab``).
        epochs: number of passes over ``train_loader``.
        device: device to train on.

    Side effects: writes ``model_epoch{N}.pth`` (the state_dict) after each epoch.
    """
    pad_id = tokenizer.stoi[tokenizer.pad]
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id)  # padded positions do not contribute

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        # ---------- TRAIN ----------
        model.train()
        total_loss = 0.0
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")

        for imgs, tokens in train_bar:
            imgs, tokens = imgs.to(device), tokens.to(device)

            optimizer.zero_grad()
            loss, _ = _sequence_loss(model, criterion, imgs, tokens)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_bar.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / max(len(train_loader), 1)

        # ---------- VALID ----------
        model.eval()
        val_loss = 0.0
        examples = []
        with torch.no_grad():
            val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]")
            for imgs, tokens in val_bar:
                imgs, tokens = imgs.to(device), tokens.to(device)
                loss, logits = _sequence_loss(model, criterion, imgs, tokens)
                val_loss += loss.item()

                # Keep the first sample of the first few batches for a qualitative check.
                if len(examples) < 3:
                    preds = logits.argmax(-1)  # [B, T-1]
                    pred_text = tokenizer.decode([t for t in preds[0].cpu().tolist() if t != pad_id])
                    true_text = tokenizer.decode([t for t in tokens[0].cpu().tolist() if t != pad_id])
                    examples.append((pred_text, true_text))

        avg_val_loss = val_loss / max(len(val_loader), 1)

        print(f"\nEpoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        for i, (pred, true) in enumerate(examples):
            print(f"  EX{i+1}: pred = {pred}")
            print(f"       true = {true}")

        torch.save(model.state_dict(), f"model_epoch{epoch+1}.pth")
        print(f"Model saved: model_epoch{epoch+1}.pth")


In [38]:
# Train on the GPU when available. Seed first so weight initialisation (and the
# DataLoader shuffling order, drawn from the global RNG) is reproducible across runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

model = FormulaRecognizer(vocab_size=len(vocab))

train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    tokenizer=vocab,
    epochs=10,
    device=device
)


Epoch 1/10 [Train]:   1%|          | 15/1250 [00:07<10:50,  1.90it/s, loss=2.38]


KeyboardInterrupt: 