In [1]:
import kagglehub

# Download the latest version of the dataset through kagglehub.
# DATA_PATH is used below as the dataset root: it is joined with
# "annotations.csv" and "images" when the annotations and pictures are read.
DATA_PATH = kagglehub.dataset_download("theseus200719/math-equations-dataset-aidav7-modified")

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import cv2
import os
import ast
import pandas as pd
import numpy as np
from tqdm import tqdm

In [3]:
# Fit images to a common size while keeping the aspect ratio; pad on the right.
def resize_and_pad(img, target_height=64, target_width=256):
    """Scale a grayscale image to ``target_height`` and right-pad it with zeros.

    The image is resized so that its height equals ``target_height`` while the
    aspect ratio is kept. If the scaled width exceeds ``target_width`` the
    image is squeezed horizontally to ``target_width``; otherwise the remaining
    columns on the right are left as zeros.

    Args:
        img: 2-D array (height, width) holding a single-channel image.
        target_height: Height of the returned array.
        target_width: Width of the returned array.

    Returns:
        A ``float32`` NumPy array of shape ``(target_height, target_width)``.

    Raises:
        ValueError: If ``img`` is not 2-D or has a zero-sized dimension.
    """
    if img.ndim != 2:
        raise ValueError(f"Expected a 2-D grayscale image, got shape {img.shape}")

    h, w = img.shape
    if h == 0 or w == 0:
        raise ValueError(f"Cannot resize an empty image of shape {img.shape}")

    scale = target_height / h
    # Clamp to [1, target_width]: a very tall, narrow image would otherwise
    # round down to a zero width, which cv2.resize rejects.
    new_w = min(max(int(w * scale), 1), target_width)

    resized = cv2.resize(img, (new_w, target_height))

    padded = np.zeros((target_height, target_width), dtype=np.float32)
    padded[:, :new_w] = resized

    return padded

In [4]:
# ===== Dataset =====
class FormulaDataset(Dataset):
    """Dataset of formula images paired with their LaTeX token ids.

    Reads ``annotations.csv`` from ``path`` and loads pictures from the
    ``images`` sub-directory. Each CSV row must provide a ``filenames`` column
    (image file name) and an ``image_data`` column containing a Python-literal
    dict with a ``full_latex_chars`` list of string tokens.

    Args:
        path: Dataset root directory.
        vocab: Object with an ``encode(list_of_tokens)`` method returning ids.
        transform: Optional callable applied to the image tensor of shape
            ``(1, H, W)`` after resizing and scaling to ``[0, 1]``.
    """

    def __init__(self, path, vocab, transform=None):
        self.data = pd.read_csv(os.path.join(path, 'annotations.csv'))
        self.img_dir = os.path.join(path, 'images')
        self.vocab = vocab
        self.transform = transform

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, token_ids)`` for the sample at position ``idx``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        row = self.data.iloc[idx]

        # --- image ---
        img_path = os.path.join(self.img_dir, row['filenames'])
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Image is missing or unreadable: {img_path}")
        img = resize_and_pad(img)
        img = img / 255.0  # scale pixel values to [0, 1]
        img = torch.tensor(img).unsqueeze(0).float()  # (1, H, W)
        if self.transform is not None:
            img = self.transform(img)

        # --- tokens ---
        info = ast.literal_eval(row['image_data'])
        tokens = info['full_latex_chars']   # list of string tokens
        token_ids = self.vocab.encode(tokens)

        return img, torch.tensor(token_ids, dtype=torch.long)


Добавить чтение файла vocab.json с токенами.

In [5]:
# ===== Vocabulary =====
class Vocab:
    """Bidirectional mapping between string tokens and integer ids.

    The four special tokens ``<pad>``, ``<bos>``, ``<eos>`` and ``<unk>`` always
    occupy ids 0-3; the remaining tokens follow in sorted order.

    Args:
        token_list: Iterable of string tokens. Duplicates and entries equal to
            a special token are dropped so every token gets exactly one id.
    """

    def __init__(self, token_list):

        # special symbols
        self.pad = "<pad>"
        self.bos = "<bos>"
        self.eos = "<eos>"
        self.unk = "<unk>"

        specials = [self.pad, self.bos, self.eos, self.unk]
        # De-duplicate and exclude specials: otherwise `tokens` would be longer
        # than the number of distinct ids and `itos` would contain gaps.
        regular = sorted(set(token_list) - set(specials))

        self.tokens = specials + regular
        self.stoi = {t: i for i, t in enumerate(self.tokens)}
        self.itos = {i: t for t, i in self.stoi.items()}

        self.length = len(self.tokens)

    def __len__(self):
        """Return the vocabulary size, special tokens included."""
        return self.length

    def encode(self, token_seq):
        """Convert a list of string tokens to ids wrapped in ``<bos>`` / ``<eos>``.

        Tokens missing from the vocabulary are mapped to ``<unk>``.

        Raises:
            ValueError: If ``token_seq`` is not a list or holds non-string items.
        """
        if not isinstance(token_seq, list):
            raise ValueError(f"Expected list of tokens, got {type(token_seq)}")

        if not all(isinstance(t, str) for t in token_seq):
            raise ValueError("All tokens must be strings")

        unk_id = self.stoi[self.unk]
        return [self.stoi[self.bos]] + \
               [self.stoi.get(t, unk_id) for t in token_seq] + \
               [self.stoi[self.eos]]

    def decode(self, ids, sep=""):
        """Convert ids back to a string, stopping at the first ``<eos>``.

        ``<pad>`` and ``<bos>`` ids are skipped.

        Args:
            ids: Iterable of integer ids (plain ints, NumPy ints or 0-d tensors).
            sep: String placed between decoded tokens. The default ``""``
                concatenates tokens directly; pass ``" "`` to keep multi-character
                tokens such as LaTeX commands separated from what follows.

        Raises:
            KeyError: If an id is not part of the vocabulary.
        """
        eos_id = self.stoi[self.eos]
        pad_id = self.stoi[self.pad]
        bos_id = self.stoi[self.bos]

        toks = []
        for i in ids:
            i = int(i)  # normalise tensor / NumPy scalars so dict lookup works
            if i == eos_id:
                break
            if i not in (pad_id, bos_id):
                toks.append(self.itos[i])

        return sep.join(toks)


In [6]:
# Collect the unique tokens that occur anywhere in the annotations.
df = pd.read_csv(os.path.join(DATA_PATH, "annotations.csv"))

# Iterate over the single column that is needed: df.iterrows() would build a
# full Series object for every row just to read one field from it.
all_token_lists = [
    ast.literal_eval(raw_info)['full_latex_chars']
    for raw_info in tqdm(df['image_data'], total=len(df))
]

unique_tokens = set()
for token_seq in all_token_lists:
    unique_tokens.update(token_seq)

100%|██████████| 100000/100000 [01:12<00:00, 1388.48it/s]


In [7]:
import json

# Persist the token inventory. The set is sorted first because set iteration
# order for strings changes between interpreter runs, which would make
# tokens.json differ from run to run for the same data.
with open("tokens.json", "w", encoding="utf-8") as f:
    json.dump(sorted(unique_tokens), f)

In [8]:
# Reload the token list from tokens.json (written by the cell above).
with open("tokens.json", "r") as f:
    token_list = json.load(f)

# Build the vocabulary from the loaded tokens and print its full token list.
vocab = Vocab(token_list)
print(vocab.tokens)

['<pad>', '<bos>', '<eos>', '<unk>', '+', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', '[', '\\cdot', '\\cos', '\\cot', '\\csc', '\\frac', '\\infty', '\\left(', '\\left|', '\\lim_', '\\ln', '\\log', '\\pi', '\\right)', '\\right|', '\\sec', '\\sin', '\\sqrt', '\\tan', '\\theta', '\\to', ']', '^', '_', 'a', 'b', 'c', 'd', 'e', 'g', 'h', 'k', 'n', 'p', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '}']


In [9]:
def collate_fn(batch, pad_id=None):
    """Collate ``(image, token_ids)`` samples into a padded mini-batch.

    Args:
        batch: Sequence of ``(image_tensor, 1-D token id tensor)`` pairs. All
            images must share one shape so they can be stacked.
        pad_id: Id used to right-pad shorter token sequences. When ``None``
            (the default) the id of ``vocab.pad`` from the notebook-level
            ``vocab`` is used, which is what the DataLoaders rely on.

    Returns:
        Tuple ``(imgs, padded_seqs)``: ``imgs`` is the stacked image tensor
        with a new leading batch dimension, ``padded_seqs`` is a ``torch.long``
        tensor of shape ``(batch, longest_sequence_in_batch)``.
    """
    if pad_id is None:
        pad_id = vocab.stoi[vocab.pad]

    imgs, seqs = zip(*batch)

    # Images already share a shape, so stacking is enough.
    imgs = torch.stack(imgs, dim=0)

    # Tokens: right-pad every sequence to the longest one in this batch.
    padded_seqs = torch.nn.utils.rnn.pad_sequence(
        list(seqs), batch_first=True, padding_value=pad_id
    ).to(torch.long)

    return imgs, padded_seqs

In [10]:
from torch.utils.data import DataLoader, random_split

# Build the dataset.
dataset = FormulaDataset(DATA_PATH, vocab)

# 80 % of the samples go to training, the remainder to validation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same samples land in train/val on every run; with an
# unseeded split, a checkpoint reloaded in a later session would be validated
# on images it may already have been trained on.
train_ds, val_ds = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, pin_memory=True, collate_fn=collate_fn, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, pin_memory=True, collate_fn=collate_fn, num_workers=2)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- CNN encoder (plain conv stack) ---
class CNNEncoder(nn.Module):
    """Convolutional feature extractor that turns an image into a token sequence.

    Three ``Conv3x3 -> BatchNorm -> ReLU`` stages (1 -> 64 -> 128 -> ``hidden_dim``
    channels); the first two are each followed by a 2x2 max-pool, so height and
    width are both reduced by a factor of four overall.
    """

    def __init__(self, hidden_dim=256):
        super().__init__()
        widths = [1, 64, 128, hidden_dim]
        final_stage = len(widths) - 2

        modules = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Conv2d(c_in, c_out, 3, stride=1, padding=1))
            modules.append(nn.BatchNorm2d(c_out))
            modules.append(nn.ReLU())
            # Every stage except the last one halves the spatial resolution.
            if stage != final_stage:
                modules.append(nn.MaxPool2d(2, 2))

        self.conv = nn.Sequential(*modules)

    def forward(self, x):
        """Map ``[B, 1, H, W]`` images to ``[B, (H/4)*(W/4), hidden_dim]`` features."""
        fmap = self.conv(x)  # [B, hidden_dim, h, w]
        # Channels last, then merge the two spatial axes row by row.
        return fmap.permute(0, 2, 3, 1).flatten(1, 2)  # [B, h*w, hidden_dim]
# --- Transformer decoder ---
class TransformerDecoder(nn.Module):
    """Autoregressive Transformer decoder over token ids with learned positions.

    Args:
        vocab_size: Number of distinct token ids (size of the output layer).
        hidden_dim: Model width; must match the encoder feature width.
        num_layers: Number of stacked decoder layers.
        nhead: Number of attention heads per layer.
        dropout: Dropout probability inside the decoder layers.
        max_len: Longest target sequence the positional embedding supports.
    """

    def __init__(self, vocab_size, hidden_dim=256, num_layers=4, nhead=8, dropout=0.1, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, hidden_dim)
        self.pos_emb = nn.Embedding(max_len, hidden_dim)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True,  # inputs are [B, T, H] rather than [T, B, H]
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, tgt, memory):
        """Predict next-token logits for every position of ``tgt``.

        Args:
            tgt: ``[B, T]`` tensor of token ids fed to the decoder.
            memory: ``[B, S, H]`` encoder features to cross-attend to.

        Returns:
            ``[B, T, vocab_size]`` logits. Position ``t`` depends only on
            ``tgt[:, :t + 1]`` and ``memory``.

        Raises:
            ValueError: If ``T`` exceeds the positional-embedding table size.
        """
        B, T = tgt.shape
        max_len = self.pos_emb.num_embeddings
        if T > max_len:
            raise ValueError(f"Target length {T} exceeds max_len={max_len}")

        positions = torch.arange(T, device=tgt.device)
        # [T, H] positional embeddings broadcast over the batch dimension.
        tgt_emb = self.token_emb(tgt) + self.pos_emb(positions)  # [B, T, H]

        # Causal mask (True = blocked): without it every position can attend to
        # later target tokens, so teacher-forced training would see the answer
        # it is asked to predict.
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=tgt.device), diagonal=1
        )

        out = self.transformer_decoder(tgt_emb, memory, tgt_mask=causal_mask)  # [B, T, H]
        return self.fc_out(out)  # [B, T, V]

# --- Full model ---
class FormulaRecognizer(nn.Module):
    """Image-to-token model: a ``CNNEncoder`` feeding a ``TransformerDecoder``."""

    def __init__(self, vocab_size, hidden_dim=256):
        super().__init__()
        self.encoder = CNNEncoder(hidden_dim=hidden_dim)
        self.decoder = TransformerDecoder(vocab_size=vocab_size, hidden_dim=hidden_dim)

    def forward(self, images, tokens):
        """Return decoder logits for ``tokens`` conditioned on ``images``."""
        # Encoder output [B, S, H] serves as the decoder's attention memory.
        return self.decoder(tokens, self.encoder(images))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os
import math

def _perplexity(mean_loss):
    """Return ``exp(mean_loss)``, or ``inf`` when the exponent overflows a float."""
    try:
        return math.exp(mean_loss)
    except OverflowError:
        return float("inf")


def _run_epoch(model, loader, criterion, pad_id, device, desc,
               optimizer=None, tokenizer=None, max_examples=0):
    """Run one teacher-forced pass over ``loader``.

    Trains when ``optimizer`` is given, otherwise evaluates without gradients.
    Returns ``(mean_loss_per_non_pad_token, examples)`` where ``examples`` is a
    list of ``(predicted_text, true_text)`` pairs collected in evaluation mode.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    token_count = 0
    examples = []
    progress = tqdm(loader, desc=desc)

    with torch.set_grad_enabled(training):
        for imgs, tokens in progress:
            imgs, tokens = imgs.to(device), tokens.to(device)

            if training:
                optimizer.zero_grad()

            # Teacher forcing: feed tokens[:-1], predict tokens[1:].
            outputs = model(imgs, tokens[:, :-1])  # [B, T-1, vocab_size]
            targets = tokens[:, 1:]                # [B, T-1]

            # Guard for models whose output length differs from the input length.
            seq_len = min(outputs.size(1), targets.size(1))
            outputs = outputs[:, :seq_len, :]
            targets = targets[:, :seq_len]

            loss = criterion(
                outputs.reshape(-1, outputs.size(-1)),
                targets.reshape(-1)
            )

            if training:
                loss.backward()
                # Gradient clipping to prevent exploding gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            batch_loss = loss.item()
            # The criterion averages over non-pad targets, so weight each batch
            # by that count to obtain a true per-token mean for the epoch.
            n_tokens = int((targets != pad_id).sum().item())
            loss_sum += batch_loss * n_tokens
            token_count += n_tokens

            if training:
                progress.set_postfix(loss=batch_loss, ppl=_perplexity(batch_loss))
            elif tokenizer is not None and len(examples) < max_examples:
                preds = outputs.argmax(-1)  # [B, seq_len]
                pred_text = tokenizer.decode(preds[0].cpu().tolist())
                true_text = tokenizer.decode(tokens[0].cpu().tolist())
                examples.append((pred_text, true_text))

    # max(..., 1) keeps an empty loader from dividing by zero.
    return loss_sum / max(token_count, 1), examples


def train_model(model, train_loader, val_loader, tokenizer, epochs=5, device="cuda"):
    """Train ``model`` with teacher forcing and validate after every epoch.

    Uses Adam (lr=1e-3, weight_decay=1e-5) and cross-entropy that ignores the
    padding id. After each epoch the weights are written to
    ``models/model_epoch{N}.pth``; the weights with the lowest validation loss
    so far are also written to ``models/best_model.pth``.

    Args:
        model: Module called as ``model(images, input_tokens)`` returning logits.
        train_loader: Iterable of ``(images, tokens)`` training batches.
        val_loader: Iterable of ``(images, tokens)`` validation batches.
        tokenizer: Vocabulary exposing ``stoi``, ``pad`` and ``decode``.
        epochs: Number of passes over ``train_loader``.
        device: Device the model and batches are moved to.

    Returns:
        List with one dict per epoch: ``{"epoch", "train_loss", "val_loss"}``,
        where losses are mean cross-entropy per non-padding token.
    """
    pad_id = tokenizer.stoi[tokenizer.pad]
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id)  # ignore padding
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # Create models directory if it doesn't exist
    os.makedirs('models', exist_ok=True)

    model = model.to(device)
    best_val_loss = float('inf')
    history = []

    for epoch in range(epochs):
        # ---------- TRAIN ----------
        avg_train_loss, _ = _run_epoch(
            model, train_loader, criterion, pad_id, device,
            desc=f"Epoch {epoch+1}/{epochs} [Train]",
            optimizer=optimizer,
        )
        train_ppl = _perplexity(avg_train_loss)

        # ---------- VALID ----------
        avg_val_loss, examples = _run_epoch(
            model, val_loader, criterion, pad_id, device,
            desc=f"Epoch {epoch+1}/{epochs} [Val]",
            tokenizer=tokenizer, max_examples=3,
        )
        val_ppl = _perplexity(avg_val_loss)

        print(f"\nEpoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} (PPL: {train_ppl:.2f}) | Val Loss: {avg_val_loss:.4f} (PPL: {val_ppl:.2f})")
        for i, (pred, true) in enumerate(examples):
            print(f"  EX{i+1}: pred = {pred}")
            print(f"       true = {true}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), os.path.join('models', 'best_model.pth'))
            print(f"New best model saved with val_loss: {avg_val_loss:.4f}")

        # Save checkpoint
        torch.save(model.state_dict(), os.path.join('models', f"model_epoch{epoch+1}.pth"))
        print(f"Checkpoint saved: model_epoch{epoch+1}.pth")

        history.append({
            "epoch": epoch + 1,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
        })

    return history


In [1]:
# NOTE: this cell depends on earlier cells having been executed in the same
# kernel session: it uses `torch`, `FormulaRecognizer`, `vocab`, `train_loader`,
# `val_loader` and `train_model`, none of which are defined here.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Size the output layer to the vocabulary (special tokens included).
model = FormulaRecognizer(vocab_size=len(vocab))

# Train for 10 epochs; `vocab` doubles as the tokenizer used for decoding examples.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    tokenizer=vocab,
    epochs=10,
    device=device
)


NameError: name 'torch' is not defined