In [None]:
!pip install transformers datasets torchvision matplotlib accelerate evaluate jiwer sacrebleu rouge-score

In [None]:
from datasets import load_dataset
import torch

# Fetch the train and validation splits of the "full" dataset configuration.
train_dataset, val_dataset = (
    load_dataset("linxy/LaTeX_OCR", name="full", split=split_name)
    for split_name in ("train", "validation")
)

# Report how many samples each split contains.
print(
    f"Train size: {len(train_dataset)}",
    f"Val size: {len(val_dataset)}",
    sep="\n",
)

In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load the stage-1 TrOCR checkpoint and the processor published with it.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")

# Special-token ids on the model config. eos_token_id is set explicitly so that
# generation stops at the tokenizer's separator token instead of always running
# to the length limit.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# generate() takes its settings from generation_config, which was created when
# the checkpoint was loaded; mirror the ids there so both configs agree.
model.generation_config.decoder_start_token_id = model.config.decoder_start_token_id
model.generation_config.pad_token_id = model.config.pad_token_id
model.generation_config.eos_token_id = model.config.eos_token_id

In [None]:
def preprocess(example):
    """Add model inputs to a single dataset record.

    Reads the record's "image" and "text" fields and stores two new fields on
    the same record: "pixel_values" (processor output for the RGB-converted
    image) and "labels" (token ids of the text, padded or truncated to 256
    tokens). The updated record is returned.
    """
    rgb_image = example["image"].convert("RGB")
    image_batch = processor(images=rgb_image, return_tensors="pt")

    text_batch = processor.tokenizer(
        example["text"],
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    )

    # Each call was given a single input; keep only its first row.
    example["pixel_values"] = image_batch.pixel_values[0]
    example["labels"] = text_batch.input_ids[0]
    return example

# Run the preprocessing over every record of both splits (train first, then validation).
train_dataset, val_dataset = (
    split.map(preprocess) for split in (train_dataset, val_dataset)
)

In [None]:
from torch.utils.data import Dataset

class LaTeXDataset(Dataset):
    """Torch ``Dataset`` view over an indexable collection of records.

    Every record must provide "pixel_values" and "labels" entries; any other
    fields of a record are dropped when it is accessed.
    """

    # Fields copied from the underlying record, in output order.
    _FIELDS = ("pixel_values", "labels")

    def __init__(self, hf_dataset):
        """Store a reference to the wrapped dataset."""
        self.dataset = hf_dataset

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return only the model-relevant fields of record ``idx``."""
        record = self.dataset[idx]
        return {field: record[field] for field in self._FIELDS}

# Wrap both preprocessed splits in the torch-style dataset defined above.
train_torch_ds, val_torch_ds = (
    LaTeXDataset(split) for split in (train_dataset, val_dataset)
)

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

# Training configuration: evaluate, log and checkpoint once per epoch, and
# score the model on generated text (predict_with_generate) rather than logits.
training_args = Seq2SeqTrainingArguments(
    output_dir="./trocr_latex",
    eval_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    # Let evaluation generate as many tokens as a label may contain (labels are
    # tokenized with max_length=256). Without this the model's default
    # generation length applies, which can cut predictions short and distort
    # CER/WER.
    generation_max_length=256,
    logging_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=100,
    save_total_limit=1,  # keep only the most recent checkpoint on disk
    logging_dir="./logs",
    report_to="none",
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
)

In [None]:
import evaluate

# Load the metric implementations once; compute_metrics reuses them on every evaluation.
cer_metric = evaluate.load("cer")
wer_metric = evaluate.load("wer")

# Label value that the loss ignores. The Trainer may also use it to pad the
# prediction/label arrays it gathers for compute_metrics.
IGNORE_INDEX = -100


def compute_metrics(pred):
    """Compute CER, WER and a simple positional character accuracy.

    Args:
        pred: evaluation output exposing ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids) as arrays.

    Returns:
        dict with the keys "cer", "wer" and "char_accuracy".
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is not a valid token id, so restore the pad id before decoding.
    # Work on copies so the arrays owned by the caller stay untouched.
    pred_ids = pred_ids.copy()
    pred_ids[pred_ids == IGNORE_INDEX] = pad_id
    labels_ids = pred.label_ids.copy()
    labels_ids[labels_ids == IGNORE_INDEX] = pad_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    # Simplest per-character accuracy: characters are compared position by
    # position, and the count is normalised by the total reference length.
    total_chars = sum(len(label) for label in label_str)
    correct_chars = sum(
        p == l
        for p_seq, l_seq in zip(pred_str, label_str)
        for p, l in zip(p_seq, l_seq)
    )
    acc = correct_chars / total_chars if total_chars > 0 else 0.0

    return {
        "cer": cer,
        "wer": wer,
        "char_accuracy": acc,
    }

from transformers import default_data_collator


def collate_with_masked_padding(features):
    """Collate a batch and exclude label padding from the loss.

    Labels are padded to a fixed length with the tokenizer's pad id; leaving
    those positions in place would train the model on padding. They are
    replaced with -100, which the loss ignores.
    NOTE(review): assumes the pad id differs from the end-of-sequence id,
    otherwise the end marker would be masked too — confirm for this tokenizer.
    """
    batch = default_data_collator(features)
    labels = batch["labels"]
    batch["labels"] = labels.masked_fill(
        labels == processor.tokenizer.pad_token_id, IGNORE_INDEX
    )
    return batch


trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=processor,  # replaces the deprecated `tokenizer=` argument
    data_collator=collate_with_masked_padding,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
import matplotlib.pyplot as plt

log_history = trainer.state.log_history


def _metric_series(history, key):
    """Return ``(epochs, values)`` for every log entry that contains ``key``.

    The x value is the epoch stored in the entry itself; when an entry has no
    "epoch" field, its 1-based position among the matching entries is used.
    """
    matching = [entry for entry in history if key in entry]
    xs = [
        entry.get("epoch", position)
        for position, entry in enumerate(matching, start=1)
    ]
    ys = [entry[key] for entry in matching]
    return xs, ys


def _save_curve_plot(curves, title, xlabel, ylabel, filename):
    """Draw curves on a new 8x6 figure and save it to ``filename``.

    ``curves`` is a list of ``(xs, ys, label, color)`` tuples; ``color`` may be
    ``None`` to use matplotlib's default colour. Curves without values are
    skipped, and no file is written when every curve is empty.
    """
    curves = [curve for curve in curves if curve[1]]
    if not curves:
        print(f"Skipping '{filename}': no logged values to plot.")
        return
    fig, ax = plt.subplots(figsize=(8, 6))
    for xs, ys, label, color in curves:
        ax.plot(xs, ys, label=label, color=color)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)
    fig.savefig(filename)


# Collect every logged series together with the epochs it was logged at, so a
# series is never plotted against an x axis of a different length.
train_epochs, train_loss = _metric_series(log_history, "loss")
epochs, eval_loss = _metric_series(log_history, "eval_loss")
cer_epochs, cer = _metric_series(log_history, "eval_cer")
wer_epochs, wer = _metric_series(log_history, "eval_wer")
char_acc_epochs, char_acc = _metric_series(log_history, "eval_char_accuracy")

# Separate figure for the loss curves
_save_curve_plot(
    [
        (train_epochs, train_loss, "Train Loss", None),
        (epochs, eval_loss, "Val Loss", None),
    ],
    title="График функции потерь",
    xlabel="Epochs",
    ylabel="Loss-function",
    filename="train_val_loss.png",
)

# Separate figure for CER
_save_curve_plot(
    [(cer_epochs, cer, "CER", 'green')],
    title="Character Error Rate (CER) Over Epochs",
    xlabel="Epoch",
    ylabel="CER",
    filename="Character Error Rate (CER).png",
)

# Separate figure for WER
_save_curve_plot(
    [(wer_epochs, wer, "WER", 'orange')],
    title="Word Error Rate (WER) Over Epochs",
    xlabel="Epoch",
    ylabel="WER",
    filename="Word Error Rate (WER).png",
)

# Separate figure for character accuracy
_save_curve_plot(
    [(char_acc_epochs, char_acc, "Accuracy", 'blue')],
    title="Accuracy Over Epochs",
    xlabel="Epoch",
    ylabel="Accuracy",
    filename="Accuracy.png",
)