# Train model

In [None]:
!pip install -q pix2tex[train] opencv-python-headless gpustat
!pip install --upgrade torch~=2.3.0 torchvision~=0.18.0 torchaudio~=2.3.0 torchtext~=0.18.0

In [None]:
# Print the current GPU status using the gpustat tool installed in the first cell.
!gpustat

## Set up workspace

In [None]:
import os
from pathlib import Path

# Remember the directory the kernel was in the first time this cell ran.
# Re-running the cell then reuses that anchor instead of creating (and
# descending into) a nested workspace/workspace directory.
_START_DIR = globals().get("_START_DIR", Path.cwd())

# Absolute path, so BASE_DIR still points at the workspace after the chdir below.
BASE_DIR = _START_DIR / "workspace"
BASE_DIR.mkdir(exist_ok=True)

# All later cells use paths relative to the workspace directory.
os.chdir(BASE_DIR)

## Download resources

In [None]:
import zipfile
import concurrent.futures
import requests
import hashlib
import time
from pathlib import Path

def compute_sha256(file_path: Path) -> str:
    """Return the hex-encoded SHA-256 digest of the file at ``file_path``.

    The file is read in 4096-byte blocks so it never has to fit in memory
    all at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        # iter() with a sentinel keeps reading until read() returns b"" (EOF).
        for block in iter(lambda: stream.read(4096), b""):
            digest.update(block)
    return digest.hexdigest()

def download_file(url: str, path: Path, expected_hash: str = None, max_retries=3, timeout=10):
    """Download ``url`` to ``path`` with retries and HTTP range resumption.

    Args:
        url: Location to fetch.
        path: Destination file. If it already exists, the download is resumed
            from its current size using an HTTP ``Range`` request.
        expected_hash: Optional hex SHA-256 of the complete file. When given,
            an existing file with that hash is left untouched, and a finished
            download whose hash differs is discarded and retried.
        max_retries: Maximum number of failed attempts before giving up.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        None. Progress and failures are reported with ``print``. After
        ``max_retries`` failed attempts the partial file is removed.
    """
    # Only hash the file when it exists; hashing a missing path raises FileNotFoundError.
    if expected_hash and path.exists() and compute_sha256(path) == expected_hash:
        print(f"{path.name} exists, skipping.")
        return

    retries = 0
    while retries < max_retries:
        try:
            print(f"Downloading {path.name} (Attempt {retries+1}/{max_retries})...")
            headers = {}

            # Ask the server to continue from the bytes we already have.
            resume_from = path.stat().st_size if path.exists() else 0
            if resume_from:
                headers["Range"] = f"bytes={resume_from}-"

            # The context manager releases the streamed connection on every exit path.
            with requests.get(url, headers=headers, stream=True, timeout=timeout, allow_redirects=True) as response:
                if resume_from and response.status_code == 416:
                    # 416 = requested range starts at or beyond the end of the
                    # resource, i.e. there is nothing left to fetch. The server
                    # may report the real size as "Content-Range: bytes */<size>".
                    total = response.headers.get("Content-Range", "").rpartition("/")[2]
                    size_matches = not total.isdigit() or int(total) == resume_from
                    if not expected_hash and size_matches:
                        print(f"{path.name} is already fully downloaded, skipping.")
                        return
                    # The local file is stale or corrupt (wrong hash or wrong
                    # size): drop it and fetch from scratch. The next request
                    # carries no Range header, so this cannot loop.
                    print(f"{path.name} cannot be resumed, restarting download.")
                    path.unlink(missing_ok=True)
                    continue

                response.raise_for_status()

                # Append only when the server actually honoured the Range request
                # (206 Partial Content). A plain 200 carries the whole file, so
                # appending it to the partial data would corrupt the result.
                mode = "ab" if response.status_code == 206 else "wb"
                with open(path, mode) as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)

            sha256_hash = compute_sha256(path)
            if expected_hash and sha256_hash != expected_hash:
                # Treat a checksum mismatch as a failed attempt and start over.
                retries += 1
                print(f"Retry {retries}/{max_retries} for {path.name} due to SHA-256 mismatch "
                      f"(got {sha256_hash}, expected {expected_hash})...")
                path.unlink(missing_ok=True)
                continue

            print(f"{path.name} download complete.")
            print(f"{path.name} SHA-256: {sha256_hash}")
            return

        except requests.RequestException as e:
            retries += 1
            print(f"Retry {retries}/{max_retries} for {path.name} due to {e}...")
            # Exponential backoff; no point sleeping after the final failure.
            if retries < max_retries:
                time.sleep(2 ** retries)

    print(f"Failed to download {path.name} after {max_retries} attempts.")
    path.unlink(missing_ok=True)

In [None]:
from pathlib import Path

# Directory (relative to the workspace) that holds the raw dataset files.
DATASET_DIR = Path("dataset/data")

DATASET_DIR.mkdir(parents=True, exist_ok=True)

# (url, destination) pairs to fetch.
# NOTE(review): the extraction cell further down expects crohme.zip and pdf.zip
# inside DATASET_DIR, but neither is downloaded here — confirm where they come from.
download_list = [
    ("https://drive.google.com/uc?id=1QUjX6PFWPa-HBWdcY-7bA5TRVUnbyS1D", DATASET_DIR / "pdfmath.txt"),
    ("https://github.com/lukas-blecher/LaTeX-OCR/releases/download/v0.0.1/weights.pth", Path("weights.pth"))
]

# Fetch the files in parallel, one worker thread per download.
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(download_file, url, path) for url, path in download_list]
    for future in concurrent.futures.as_completed(futures):
        # result() re-raises any exception from the worker thread; a plain
        # wait() would silently discard it and let later cells fail confusingly.
        future.result()

## Extract resources

In [None]:
import zipfile
from pathlib import Path

def extract_zip(file_path: Path, extract_to: Path):
    """Unpack every member of the zip archive ``file_path`` into ``extract_to``.

    Prints one message before and one after the extraction.
    """
    print(f"Extracting {file_path}...")
    archive = zipfile.ZipFile(file_path, "r")
    try:
        archive.extractall(extract_to)
    finally:
        # Always release the archive handle, even if extraction fails.
        archive.close()
    print(f"Extraction complete: {extract_to}")

In [None]:
import shutil
import random

# Directory that holds the training images.
IMAGES_DIR = DATASET_DIR / "images"
IMAGES_DIR.mkdir(parents=True, exist_ok=True)

# Unpack both dataset archives directly into DATASET_DIR, in this order.
for archive_name in ("crohme.zip", "pdf.zip"):
    extract_zip(DATASET_DIR / archive_name, DATASET_DIR)

## Prepare data

In [None]:
# Hold out a fixed-size random subset of the images for validation by moving
# them from IMAGES_DIR into VAL_DIR.
VAL_DIR = DATASET_DIR / "valimages"
VAL_SIZE = 1000  # target number of validation images
VAL_SEED = 42    # fixed seed so the split is the same on every fresh run

VAL_DIR.mkdir(parents=True, exist_ok=True)

# Sort the candidates: glob order depends on the filesystem, and a stable
# order is needed for the seeded sample to be reproducible.
image_files = sorted(p for p in IMAGES_DIR.glob("*") if p.is_file())

# Only top the validation set up to VAL_SIZE, so re-running this cell does not
# drain another batch of images out of the training set. Also cap at the number
# of available images, since random.sample raises ValueError when asked for
# more items than exist.
n_held_out = sum(1 for p in VAL_DIR.iterdir() if p.is_file())
n_to_move = min(max(VAL_SIZE - n_held_out, 0), len(image_files))
val_files = set(random.Random(VAL_SEED).sample(image_files, n_to_move))

for file in val_files:
    dest = VAL_DIR / file.name
    # Never overwrite a validation image that is already in place.
    if not dest.exists():
        shutil.move(str(file), str(dest))

In [None]:
# Build the training dataset pickle with pix2tex's dataset module.
# Presumably -i lists image directories, -e the equation text files and -o the
# output file — confirm with `python -m pix2tex.dataset.dataset --help`.
!python -m pix2tex.dataset.dataset \
    -i dataset/data/images dataset/data/train \
    -e dataset/data/CROHME_math.txt dataset/data/pdfmath.txt \
    -o dataset/data/train.pkl

In [None]:
# Build the validation dataset pickle the same way, reading the held-out images
# in dataset/data/valimages together with dataset/data/val.
# Presumably -i lists image directories, -e the equation text files and -o the
# output file — confirm with `python -m pix2tex.dataset.dataset --help`.
!python -m pix2tex.dataset.dataset \
    -i dataset/data/valimages dataset/data/val \
    -e dataset/data/CROHME_math.txt dataset/data/pdfmath.txt \
    -o dataset/data/val.pkl

In [None]:
%%writefile colab.yaml
backbone_layers: [2, 3, 7]
betas: [0.9, 0.999]
batchsize: 10
bos_token: 1
channels: 1
data: dataset/data/train.pkl
debug: true
decoder_args:
  attn_on_attn: true
  cross_attend: true
  ff_glu: true
  rel_pos_bias: false
  use_scalenorm: false
dim: 256
encoder_depth: 4
eos_token: 2
epochs: 50
gamma: 0.9995
heads: 8
id: null
load_chkpt: weights.pth
lr: 0.001
lr_step: 30
max_height: 192
max_seq_len: 512
max_width: 672
min_height: 32
min_width: 32
model_path: checkpoints
name: mixed
num_layers: 4
num_tokens: 8000
optimizer: Adam
output_path: outputs
pad: false
pad_token: 0
patch_size: 16
sample_freq: 2000
save_freq: 1
scheduler: StepLR
seed: 42
temperature: 0.2
test_samples: 5
testbatchsize: 20
tokenizer: dataset/tokenizer.json
valbatches: 100
valdata: dataset/data/val.pkl

In [None]:
# Launch pix2tex training with the colab.yaml config written by the previous cell.
!python -m pix2tex.train --config colab.yaml