# Im2LaTeX Training - Complete Notebook

This notebook combines all of the training code for mathematical formula recognition (formula image → LaTeX) using a CNN encoder + LSTM decoder architecture.

Optimized for **Kaggle** and **Google Colab**.

## 1. Setup Environment & Import Dependencies

In [None]:
import sys
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from tqdm import tqdm
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from collections import Counter
import requests
import zipfile
import re


# Work out which hosted platform (if any) this kernel is running on.
if hasattr(__builtins__, '__IPYTHON__'):
    # Inside IPython: the shell object's repr contains 'google.colab' on Colab.
    IN_COLAB = 'google.colab' in str(get_ipython())
else:
    IN_COLAB = False
# Treated as Kaggle whenever the KAGGLE_URL_BASE environment variable is set.
IN_KAGGLE = 'KAGGLE_URL_BASE' in os.environ

# Colab takes precedence over Kaggle, which takes precedence over "Local".
if IN_COLAB:
    platform_name = 'Google Colab'
elif IN_KAGGLE:
    platform_name = 'Kaggle'
else:
    platform_name = 'Local'

print(f"Running in: {platform_name}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Report the first CUDA device and its total memory in GiB when a GPU is present.
if torch.cuda.is_available():
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_properties.total_memory / 1024**3:.1f} GB")

Running in: Google Colab
PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4
GPU Memory: 14.7 GB


## 2. Data Preparation & Download

In [None]:
# Download and unpack the Im2LaTeX dataset (Zenodo record 56198)
import os
import requests
import tarfile
import zipfile


def unpack_zip(file_name, destination_folder):
    """Unpack a ZIP archive into a folder and delete the archive afterwards.

    Args:
        file_name: Path of the ZIP archive to read.
        destination_folder: Directory that receives the archive contents.

    The archive file is removed only after extraction has finished, so a
    failed extraction leaves it in place.
    """
    start_message = f"üì¶ Extracting {file_name} to {destination_folder}..."
    print(start_message)

    archive = zipfile.ZipFile(file_name, mode='r')
    with archive:
        archive.extractall(path=destination_folder)

    # The archive is no longer needed once its contents are on disk.
    os.remove(file_name)
    print("‚úÖ ZIP extraction completed!")

def unpack_targz(file_name, destination_folder):
    """Extract a .tar.gz archive into a folder and delete the archive afterwards.

    Args:
        file_name: Path of the gzip-compressed tar archive.
        destination_folder: Directory that receives the archive contents.

    Extraction uses the standard library's 'data' filter when the running
    Python provides it. The filter rejects members that would land outside
    the destination (absolute paths, '..' traversal, unsafe links) and avoids
    the DeprecationWarning that unfiltered extraction raises on Python 3.12+.
    """
    print(f"üì¶ Extracting {file_name} to {destination_folder}...")
    with tarfile.open(file_name, 'r:gz') as tar_ref:
        if hasattr(tarfile, 'data_filter'):
            tar_ref.extractall(destination_folder, filter='data')
        else:
            # Older interpreters without extraction filters: keep legacy behavior.
            tar_ref.extractall(destination_folder)
    os.remove(file_name)
    print("‚úÖ TAR.GZ extraction completed!")

def download_dataset():
    """Download the Im2LaTeX archive and unpack it under datasets/im2latex/.

    Steps:
      1. Stream the Zenodo "files-archive" ZIP to ``im2latex.zip`` in the
         current working directory.
      2. Unpack the ZIP into ``datasets/im2latex/`` (the ZIP is deleted).
      3. Unpack ``formula_images.tar.gz`` found inside it (also deleted).
      4. Print the resulting directory listing.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never saved and "unzipped" by mistake.
    """
    print("üì• Downloading Im2LaTeX dataset...")

    url = "https://zenodo.org/api/records/56198/files-archive"

    file_name = 'im2latex.zip'
    destination_folder = 'datasets/im2latex/'

    # Stream to disk in 1 MiB chunks instead of holding the whole archive in RAM.
    downloaded_bytes = 0
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(file_name, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
                    downloaded_bytes += len(chunk)
    print(f"‚úÖ Downloaded: {downloaded_bytes / (1024*1024):.1f} MB")

    # Ensure the destination folder exists
    os.makedirs(destination_folder, exist_ok=True)

    unpack_zip(file_name, destination_folder)

    # The outer ZIP contains the rendered images as a nested tarball.
    unpack_targz('datasets/im2latex/formula_images.tar.gz', destination_folder)

    # Show what we got
    print("\nüìÅ Files extracted:")
    for item in os.listdir("datasets/im2latex"):
        print(f"   - {item}")

# Fetch and unpack the dataset into datasets/im2latex/ (relative to the current working directory).
download_dataset()
# Make sure the output directory used later for the vocabulary and model checkpoints exists.
os.makedirs('/content/models', exist_ok=True)

üì• Downloading Im2LaTeX dataset...
‚úÖ Downloaded: 292.6 MB
üì¶ Extracting im2latex.zip to datasets/im2latex/...
‚úÖ ZIP extraction completed!
üì¶ Extracting datasets/im2latex/formula_images.tar.gz to datasets/im2latex/...


  tar_ref.extractall(destination_folder)


‚úÖ TAR.GZ extraction completed!

üìÅ Files extracted:
   - im2latex_validate.lst
   - formula_images
   - im2latex_train.lst
   - im2latex_formulas.lst
   - readme.txt
   - im2latex_test.lst


In [None]:
import json
from pathlib import Path
from PIL import Image
import sys
import re
from collections import Counter

# Filesystem layout: everything lives under ROOT.
ROOT = Path("/content")
DATASET_ROOT = ROOT.joinpath("datasets", "im2latex")
# Output of this cell: the tokenized splits plus the vocabulary, as one JSON file.
OUTPUT_PATH = ROOT.joinpath("datasets", "im2latex_prepared_tokens.json")
# Stand-alone copy of the vocabulary, stored next to the model files.
VOCAB_PATH = ROOT.joinpath("models", "vocab.json")

# Dataset inputs: the formula list, the three split lists and the image folder.
FORMULA_FILE, TRAIN_LST, VAL_LST, TEST_LST = (
    DATASET_ROOT / list_name
    for list_name in (
        "im2latex_formulas.lst",
        "im2latex_train.lst",
        "im2latex_validate.lst",
        "im2latex_test.lst",
    )
)
IMG_ROOT = DATASET_ROOT / "formula_images"

# A '%' starts a comment unless it is escaped. Group 1 keeps any run of
# escaped backslashes ("\\\\") that directly precedes a real comment sign.
_COMMENT_RE = re.compile(r'(?<!\\)((?:\\\\)*)%.*$', re.MULTILINE)
# \begin{name} / \end{name} are kept together as one token.
_ENV_RE = re.compile(r'\\(?:begin|end)\{[^}]+\}')
# A control word (\alpha) or a control symbol (\, \{ \\ ...).
_CMD_RE = re.compile(r'\\(?:[a-zA-Z]+|[^a-zA-Z])')


def tokenize_latex(latex_str):
    """Split a LaTeX string into a flat list of tokens.

    Token kinds, in matching priority:
      * environment delimiters such as ``\\begin{array}`` / ``\\end{array}``
        (one token each, including the braces and the environment name);
      * commands: a backslash followed by letters (``\\frac``) or by a single
        non-letter character (``\\,``, ``\\{``, ``\\%``);
      * every other non-whitespace character as its own token (braces,
        operators, punctuation, single digits, single letters, a lone
        trailing backslash).

    Comments (an unescaped ``%`` up to the end of the line) are removed first
    and whitespace is dropped.

    Args:
        latex_str: The LaTeX source to tokenize.

    Returns:
        list[str]: Tokens in source order; an empty list for empty input.
    """
    # Remove comments first; '\%' is a literal percent sign and must survive.
    latex_str = _COMMENT_RE.sub(r'\1', latex_str)

    tokens = []
    i = 0
    n = len(latex_str)

    while i < n:
        ch = latex_str[i]

        if ch == '\\':
            # Environments must be tried before generic commands; otherwise
            # '\begin' alone would always win and the name would be split
            # into single letters.
            found = _ENV_RE.match(latex_str, i) or _CMD_RE.match(latex_str, i)
            if found:
                tokens.append(found.group(0))
                i = found.end()
                continue
            # A backslash at the very end of the string falls through and is
            # emitted as a single-character token below.

        if not ch.isspace():
            tokens.append(ch)
        i += 1

    return tokens

def build_vocab_from_formulas(formulas, min_freq=2):
    """Create the token-to-id vocabulary and write it to VOCAB_PATH as JSON.

    Args:
        formulas: Iterable of LaTeX strings; each is split with tokenize_latex.
        min_freq: A token is kept only if it occurs at least this many times
            across all formulas.

    Returns:
        dict: Token -> integer id. Ids 0-3 are reserved for '<SOS>', '<EOS>',
        '<PAD>' and '<UNK>'; the remaining ids are handed out from the most
        frequent token to the least frequent one.
    """
    # Count token occurrences formula by formula.
    frequencies = Counter()
    for expression in formulas:
        frequencies.update(tokenize_latex(expression))

    vocab = {'<SOS>': 0, '<EOS>': 1, '<PAD>': 2, '<UNK>': 3}

    # Frequent tokens first; the next free id always equals the current size.
    for token, occurrences in frequencies.most_common():
        if occurrences < min_freq or token in vocab:
            continue
        vocab[token] = len(vocab)

    print(f"Built vocabulary with {len(vocab)} tokens")

    # Persist the vocabulary, creating the target folder when it is missing.
    VOCAB_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(VOCAB_PATH, 'w', encoding='utf-8') as vocab_file:
        json.dump(vocab, vocab_file, ensure_ascii=False, indent=2)
    print(f"‚úÖ Saved vocabulary to {VOCAB_PATH}")

    return vocab

def load_formulas():
    """Read FORMULA_FILE and return one stripped formula string per line.

    The position of a formula in the returned list is what the split lists
    refer to, so lines must be split on '\\n' only. With Python's default
    universal-newline handling a bare '\\r' inside a formula would also end a
    line, adding extra entries and shifting every later index.
    ``newline="\\n"`` prevents that; files with plain '\\n' or '\\r\\n' line
    endings give the same result as before because of the ``strip()``.

    Returns:
        list[str]: Formulas in file order.
    """
    with open(FORMULA_FILE, "r", encoding="latin-1", newline="\n") as f:
        formulas = [line.strip() for line in f]
    print(f"Loaded {len(formulas)} formulas")
    return formulas

def load_split(lst_path, formulas, vocab):
    """Build the sample list for one dataset split.

    Each non-empty line of ``lst_path`` has three whitespace-separated
    fields: the formula index, the image name (without extension) and a
    third field that is ignored here. Entries whose PNG is missing under
    IMG_ROOT are skipped.

    Args:
        lst_path: Path object of the split list file.
        formulas: List of formula strings, indexed by formula index.
        vocab: Token -> id mapping; must contain '<UNK>', which is used for
            tokens that are not in the vocabulary.

    Returns:
        list[dict]: One dict per usable sample with keys 'img_path' (str),
        'tokens' (list of token ids) and 'latex' (the original formula).

    Raises:
        ValueError: if a non-empty line does not have exactly three fields.
    """
    samples = []
    unk_id = vocab['<UNK>']
    with open(lst_path, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank or whitespace-only lines (e.g. a trailing newline) carry no sample.
                continue
            f_id, img_name, _ = fields
            f_id = int(f_id)
            img_path = IMG_ROOT / (img_name + ".png")
            if img_path.exists():
                latex_formula = formulas[f_id]
                tokens = tokenize_latex(latex_formula)
                token_ids = [vocab.get(token, unk_id) for token in tokens]

                samples.append({
                    "img_path": str(img_path),
                    "tokens": token_ids,  # Store token IDs instead of string
                    "latex": latex_formula  # Keep original for reference
                })
    print(f"Loaded {len(samples)} samples from {lst_path.name}")
    return samples

def main():
    """Tokenize the dataset and write all splits plus the vocabulary to OUTPUT_PATH.

    The vocabulary is built from every formula in the formula file (tokens
    seen fewer than two times are left out); the train / val / test sample
    lists are then created with that vocabulary and dumped as a single JSON
    document with the keys 'train', 'val', 'test' and 'vocab'.
    """
    formulas = load_formulas()
    vocab = build_vocab_from_formulas(formulas, min_freq=2)

    # One sample list per split, in the order train -> val -> test.
    split_files = (("train", TRAIN_LST), ("val", VAL_LST), ("test", TEST_LST))
    dataset = {
        split_name: load_split(lst_file, formulas, vocab)
        for split_name, lst_file in split_files
    }
    # Ship the vocabulary inside the same file so training needs one input only.
    dataset["vocab"] = vocab

    with open(OUTPUT_PATH, "w", encoding="utf-8") as out_file:
        json.dump(dataset, out_file, indent=2)

    print(f"\nSaved tokenized dataset to: {OUTPUT_PATH}")
    print(f"Dataset includes vocabulary with {len(vocab)} tokens")


if __name__ == "__main__":
    main()

Loaded 104564 formulas
Built vocabulary with 549 tokens
‚úÖ Saved vocabulary to /content/models/vocab.json
Loaded 83884 samples from im2latex_train.lst
Loaded 9320 samples from im2latex_validate.lst
Loaded 10355 samples from im2latex_test.lst

Saved tokenized dataset to: /content/datasets/im2latex_prepared_tokens.json
Dataset includes vocabulary with 549 tokens


## 3. Preprocessing Module

In [None]:
def detect_content_bounding_box(img, background_threshold=240, padding=10):
    """Find the padded bounding box of the non-background content of an image.

    Content pixels are:
      * for 4-channel images: pixels whose 4th channel (alpha) is > 10;
      * otherwise: pixels darker than ``background_threshold`` (for
        multi-channel images, in any channel).

    Args:
        img: NumPy image array, 2-D (H, W) or 3-D (H, W, C).
        background_threshold: Values below this count as content.
        padding: Margin in pixels added on every side, clipped to the image.

    Returns:
        tuple[int, int, int, int]: ``(x1, y1, x2, y2)`` where x2 / y2 are
        exclusive, so ``img[y1:y2, x1:x2]`` contains all content. When the
        image holds no content the whole image ``(0, 0, width, height)`` is
        returned.
    """
    if img.ndim == 3 and img.shape[2] == 4:  # 4 channels: use alpha
        mask = img[:, :, 3] > 10  # Non-transparent areas
    else:
        mask = img < background_threshold
        if mask.ndim == 3:
            # Collapse channels so the mask is always (H, W).
            mask = mask.any(axis=2)

    height, width = img.shape[:2]
    if not mask.any():
        return 0, 0, width, height

    content_rows = np.flatnonzero(mask.any(axis=1))
    content_cols = np.flatnonzero(mask.any(axis=0))

    # The last index is inclusive; +1 turns it into an exclusive slice bound
    # so the final content row/column is not cut off.
    y1 = max(0, int(content_rows[0]) - padding)
    y2 = min(height, int(content_rows[-1]) + 1 + padding)
    x1 = max(0, int(content_cols[0]) - padding)
    x2 = min(width, int(content_cols[-1]) + 1 + padding)

    return x1, y1, x2, y2

def smart_crop_and_resize(img_gray, target_h=128, target_w=512, min_aspect_ratio=0.2, max_aspect_ratio=5.0):
    """Crop a grayscale image to its content and fit it onto a white canvas.

    The content box comes from detect_content_bounding_box. The crop's
    width/height ratio is clamped to [min_aspect_ratio, max_aspect_ratio],
    the crop is resized (bicubic) so that it fills the target width or the
    target height, whichever the clamped ratio allows, with a lower limit of
    32 pixels per side, and the result is centered on a
    ``target_h`` x ``target_w`` canvas filled with 255.

    Args:
        img_gray: 2-D uint8 image.
        target_h, target_w: Size of the returned canvas.
        min_aspect_ratio, max_aspect_ratio: Clamp for width / height.

    Returns:
        np.ndarray: uint8 array of shape (target_h, target_w).
    """
    left, top, right, bottom = detect_content_bounding_box(img_gray)
    content = img_gray[top:bottom, left:right]
    if content.size == 0:
        # Degenerate box: fall back to the uncropped image.
        content = img_gray

    content_h, content_w = content.shape
    aspect = max(min_aspect_ratio, min(max_aspect_ratio, content_w / content_h))

    if aspect > (target_w / target_h):
        # Wider than the canvas: width is the limiting side.
        out_w, out_h = target_w, int(target_w / aspect)
    else:
        # Taller than (or as wide as) the canvas: height is the limiting side.
        out_w, out_h = int(target_h * aspect), target_h

    out_w, out_h = max(out_w, 32), max(out_h, 32)

    scaled = cv2.resize(content, (out_w, out_h), interpolation=cv2.INTER_CUBIC)

    # White canvas with the scaled content centered on it.
    canvas = np.full((target_h, target_w), 255, dtype=np.uint8)
    top_margin = (target_h - out_h) // 2
    left_margin = (target_w - out_w) // 2
    canvas[top_margin:top_margin + out_h, left_margin:left_margin + out_w] = scaled

    return canvas

def adaptive_binarize(img_gray, method='otsu', block_size=35, C=10):
    """Binarize a grayscale image with Otsu or Gaussian adaptive thresholding.

    Args:
        img_gray: Single-channel image passed straight to OpenCV.
        method: 'otsu' selects a global Otsu threshold; any other value
            selects Gaussian adaptive thresholding.
        block_size: Neighbourhood size for the adaptive method.
        C: Constant forwarded to the adaptive method.

    Returns:
        The thresholded image returned by OpenCV (maximum value 255).
    """
    if method != 'otsu':
        return cv2.adaptiveThreshold(
            img_gray,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            block_size,
            C,
        )

    # With THRESH_OTSU the threshold argument (0) is ignored and computed by OpenCV.
    otsu_flags = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    _, otsu_image = cv2.threshold(img_gray, 0, 255, otsu_flags)
    return otsu_image

def preprocess_for_model(image_path: str, target_h=128, target_w=512):
    """Load a formula image and turn it into a normalized model input.

    Pipeline: read the file unchanged, convert to grayscale (transparent
    pixels become white), crop/resize/pad with smart_crop_and_resize,
    binarize with both the adaptive and the Otsu method and keep one of
    them, apply a 2x2 morphological opening, then invert and scale.

    Args:
        image_path: Path of the image file.
        target_h, target_w: Output size forwarded to smart_crop_and_resize.

    Returns:
        np.ndarray: Float array where 1.0 corresponds to a black (0) pixel of
        the binarized image and 0.0 to a white (255) pixel.

    Raises:
        ValueError: if OpenCV cannot read the image.
    """
    img = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError(f"Could not load image: {image_path}")

    if len(img.shape) == 3:
        # cv2.imread delivers color channels in B, G, R(, A) order, so the
        # BGR conversion code is the one that applies the right weights.
        if img.shape[2] == 4:  # BGRA
            bgr = img[:, :, :3]
            alpha = img[:, :, 3]
            img_gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
            img_gray[alpha < 10] = 255  # (nearly) transparent -> white background
        else:  # BGR
            img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        img_gray = img

    processed = smart_crop_and_resize(img_gray, target_h=target_h, target_w=target_w)
    binary_adaptive = adaptive_binarize(processed, method='adaptive')
    binary_otsu = adaptive_binarize(processed, method='otsu')

    # Count dark pixels produced by each method.
    adaptive_non_white = np.sum(binary_adaptive < 128)
    otsu_non_white = np.sum(binary_otsu < 128)

    # Keep the adaptive result unless it has at most 30% of the dark pixels
    # that Otsu found; in that case use the Otsu result.
    if adaptive_non_white > otsu_non_white * 0.3:
        binary = binary_adaptive
    else:
        binary = binary_otsu

    kernel = np.ones((2, 2), np.uint8)
    cleaned = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
    # Invert so that dark pixels map to 1.0 and white background to 0.0.
    normalized = (255 - cleaned) / 255.0

    return normalized

## 4. Model Architecture

In [None]:
class Im2Latex(nn.Module):
    """CNN encoder + LSTM decoder that maps a formula image to token logits.

    The encoder turns a 1-channel image into a 512 x 4 x 32 feature map (the
    final adaptive pooling fixes that size). Each of the 32 columns is
    flattened into a 2048-dim vector; decoding step t is fed column
    ``t % 32`` concatenated with the embedding of input token t.

    Args:
        vocab_size: Number of tokens (embedding rows and output logits).
        embed_dim: Token embedding size.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout on the LSTM input and, when ``num_layers > 1``,
            between LSTM layers.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, num_layers=2, dropout=0.3):
        super().__init__()

        # (in_channels, out_channels, pooling window) for each conv stage.
        # Shapes noted for a 128x512 input, as [C, H, W] after pooling.
        stage_specs = [
            (1, 64, 2),           # [64, 64, 256]
            (64, 128, 2),         # [128, 32, 128]
            (128, 256, 2),        # [256, 16, 64]
            (256, 512, 2),        # [512, 8, 32]
            (512, 512, (2, 1)),   # [512, 4, 32] - halves the height only
        ]
        encoder_layers = []
        for in_channels, out_channels, pool_window in stage_specs:
            encoder_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool_window, stride=pool_window),
            ])
        # Force a [512, 4, 32] map regardless of the input resolution.
        encoder_layers.append(nn.AdaptiveAvgPool2d((4, 32)))
        self.cnn_encoder = nn.Sequential(*encoder_layers)

        # Index 2 is the padding token id; its embedding is not trained.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=2)

        column_feature_dim = 512 * 4  # channels x height of one feature column
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=column_feature_dim + embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )
        self.output_proj = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, images, input_tokens):
        """Compute next-token logits for every position of ``input_tokens``.

        Args:
            images: Float tensor [B, 1, H, W].
            input_tokens: Long tensor [B, seq_len] of token ids.

        Returns:
            Tensor [B, seq_len, vocab_size] of unnormalized logits.
        """
        feature_map = self.cnn_encoder(images)                    # [B, 512, 4, 32]
        # Merge channels and height, then put the column axis second.
        columns = feature_map.flatten(1, 2).transpose(1, 2)       # [B, 32, 2048]

        embedded = self.embedding(input_tokens)                   # [B, seq_len, embed_dim]

        # Tile the 32 columns along the time axis and cut to seq_len.
        seq_len = input_tokens.size(1)
        tile_count = seq_len // 32 + 1
        per_step_features = columns.repeat(1, tile_count, 1)[:, :seq_len, :]

        decoder_input = self.dropout(torch.cat([per_step_features, embedded], dim=-1))
        hidden_states, _ = self.lstm(decoder_input)

        return self.output_proj(hidden_states)

## 5. Dataset & DataLoader

In [None]:
class Im2LatexDataset(Dataset):
    """Formula-image dataset with an in-memory cache of preprocessed tensors.

    Every entry of ``data_list`` is a dict with an 'img_path' and a 'tokens'
    list of token ids. An item is a triple
    ``(image [1, 128, 512], decoder input, decoder target)`` where input and
    target are the SOS/EOS-framed, padded token sequence shifted by one
    position (length ``max_len - 1`` each).
    """

    def __init__(self, data_list, vocab, max_len=256, cache_images=True, preload_images=True):
        """Store the samples and, if requested, preprocess all images up front.

        Args:
            data_list: List of sample dicts ('img_path', 'tokens').
            vocab: Token -> id mapping; special ids default to 0/1/2/3 for
                '<SOS>'/'<EOS>'/'<PAD>'/'<UNK>' when missing.
            max_len: Length of the framed token sequence before shifting.
            cache_images: Keep tensors that were loaded on demand.
            preload_images: Preprocess every image in the constructor.
        """
        self.data = data_list
        self.vocab = vocab
        self.max_len = max_len
        self.sos_token = vocab.get('<SOS>', 0)
        self.eos_token = vocab.get('<EOS>', 1)
        self.pad_token = vocab.get('<PAD>', 2)
        self.unk_token = vocab.get('<UNK>', 3)

        self.cache_images = cache_images
        self.preload_images = preload_images
        # img_path (as stored in the sample) -> preprocessed image tensor
        self._image_cache = {}

        if not self.preload_images:
            print(f"üìÅ Dataset created with {len(data_list)} samples (on-demand loading)")
        else:
            print(f"üîÑ Preloading {len(data_list)} images...")
            self._preload_all_images()

    @staticmethod
    def _resolve_image_path(img_path):
        """Return the path to open: '/content/...' paths as-is, others under the dataset folder."""
        if img_path.startswith('/content/'):
            return img_path
        return f"/content/datasets/im2latex/{img_path}"

    def _preload_all_images(self):
        """Preprocess every image with four worker threads and fill the cache.

        Images that fail to load are reported and replaced by an all-zero
        tensor, so every sample ends up with a cache entry.
        """
        from concurrent.futures import ThreadPoolExecutor
        import threading

        loaded_count = 0
        cache_lock = threading.Lock()  # guards the cache dict and the counter

        def load_single_image(sample_index):
            nonlocal loaded_count
            img_path = self.data[sample_index]['img_path']
            full_path = self._resolve_image_path(img_path)

            try:
                pixels = preprocess_for_model(full_path)
                img_tensor = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0)

                with cache_lock:
                    self._image_cache[img_path] = img_tensor
                    loaded_count += 1

                    # Progress line every 1000 successfully loaded images.
                    if loaded_count % 1000 == 0:
                        print(f"  Loaded {loaded_count}/{len(self.data)} images...")

            except Exception as e:
                print(f"Error loading {full_path}: {e}")
                with cache_lock:
                    self._image_cache[img_path] = torch.zeros(1, 128, 512, dtype=torch.float32)
                    loaded_count += 1

        with ThreadPoolExecutor(max_workers=4) as executor:
            pending = executor.map(load_single_image, range(len(self.data)))
            # Draining the iterator drives the progress bar and waits for all workers.
            list(tqdm(pending, total=len(self.data), desc="Preloading images"))

        print(f"‚úÖ Preloaded {loaded_count} images into memory")

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image_tensor, input_seq, target)`` for sample ``idx``."""
        record = self.data[idx]
        img_path = record['img_path']
        token_ids = record['tokens']  # already converted to ids during data preparation

        img_tensor = self._image_cache.get(img_path)
        if img_tensor is None:
            # Cache miss: preprocess now; fall back to a blank image on failure.
            full_path = self._resolve_image_path(img_path)
            try:
                pixels = preprocess_for_model(full_path)
                img_tensor = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0)
                if self.cache_images:
                    self._image_cache[img_path] = img_tensor
            except Exception as e:
                print(f"Error loading {full_path}: {e}")
                img_tensor = torch.zeros(1, 128, 512, dtype=torch.float32)

        # Frame with SOS/EOS, then bring the sequence to exactly max_len ids.
        sequence = [self.sos_token, *token_ids, self.eos_token]
        overflow = len(sequence) - self.max_len
        if overflow > 0:
            sequence = sequence[:self.max_len]
            sequence[-1] = self.eos_token  # a truncated sequence still ends with EOS
        else:
            sequence += [self.pad_token] * (-overflow)

        tokens_tensor = torch.tensor(sequence, dtype=torch.long)

        # Teacher forcing: the target is the input shifted left by one token.
        input_seq = tokens_tensor[:-1]
        target = tokens_tensor[1:].clone()

        return img_tensor, input_seq, target

def collate_fn(batch):
    """Stack a list of ``(image, input_seq, target)`` samples into batch tensors.

    Args:
        batch: Sequence of 3-tuples of equally shaped tensors.

    Returns:
        tuple: ``(imgs, input_seqs, targets)``, each with a new leading batch
        dimension.
    """
    # Transpose sample-major to field-major, then stack each field.
    image_list, input_list, target_list = zip(*batch)
    stacked = [torch.stack(field) for field in (image_list, input_list, target_list)]
    return stacked[0], stacked[1], stacked[2]

## 6. Training & Validation Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Train the model for one pass over ``dataloader``.

    Args:
        model: Called as ``model(imgs, input_seqs)``; returns
            [B, seq_len, vocab] logits.
        dataloader: Iterable of ``(imgs, input_seqs, targets)`` batches.
        criterion: Loss applied to the non-padding positions.
        optimizer: Optimizer updating the model parameters.
        device: Device the batches are moved to.
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        float: Mean loss over the batches that contributed an update (0.0 if
        there were none). Batches consisting only of padding are skipped.
    """
    model.train()
    cuda_available = torch.cuda.is_available()
    running_loss = 0.0
    steps_done = 0

    progress = tqdm(dataloader, desc=f"Epoch {epoch}", ncols=80)

    for batch in progress:
        imgs, input_seqs, targets = (
            tensor.to(device, non_blocking=cuda_available) for tensor in batch
        )

        optimizer.zero_grad()
        logits = model(imgs, input_seqs)

        # Flatten batch and time so every position is one classification case.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = targets.reshape(-1)

        keep = flat_targets != 2  # PAD token = 2
        if not (keep.sum() > 0):
            continue

        loss = criterion(flat_logits[keep], flat_targets[keep])
        loss.backward()
        # Clip the global gradient norm to 0.5 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        running_loss += loss.item()
        steps_done += 1

        status = {
            'loss': f'{loss.item():.3f}',
            'avg': f'{running_loss / steps_done:.3f}',
        }
        # Show allocated GPU memory on every 20th step.
        if cuda_available and steps_done % 20 == 0:
            status['GPU'] = f'{torch.cuda.memory_allocated(0) / 1024**3:.1f}GB'
        progress.set_postfix(status)

    if steps_done == 0:
        return 0.0
    return running_loss / steps_done

def validate(model, dataloader, criterion, device, vocab):
    """Evaluate loss and exact-match accuracy on a validation loader.

    For every sample the formula is generated from the image alone with
    generate_from_image and compared, as a string of concatenated tokens,
    with the reference (ids 0, 1 and 2 are removed from the reference). The
    teacher-forced loss over non-padding positions is computed per batch as
    well.

    Args:
        model: Model under evaluation.
        dataloader: Iterable of ``(imgs, input_seqs, targets)`` batches.
        criterion: Loss function.
        device: Device the batches are moved to.
        vocab: Token -> id mapping, inverted here for decoding.

    Returns:
        tuple[float, float]: ``(average loss, exact-match accuracy)``.
    """
    model.eval()
    loss_sum = 0.0
    loss_batches = 0
    exact_matches = 0
    total_samples = 0

    id_to_token = {token_id: token for token, token_id in vocab.items()}
    special_ids = [0, 1, 2]  # SOS / EOS / PAD are not part of the text

    with torch.no_grad():
        for imgs, input_seqs, targets in tqdm(dataloader, desc="Validation", ncols=100, leave=False):
            imgs = imgs.to(device)

            # Free-running generation, one image at a time.
            for row in range(imgs.size(0)):
                reference_ids = targets[row].cpu().tolist()
                actual_text = ''.join(
                    id_to_token.get(token_id, '?')
                    for token_id in reference_ids
                    if token_id not in special_ids
                )
                predicted_text = generate_from_image(model, imgs[row:row+1], vocab, device, max_len=256)

                if predicted_text == actual_text:
                    exact_matches += 1
                elif total_samples < 3:
                    # Show mismatches among the first three samples for debugging.
                    print(f"  ‚ùå Pred: '{predicted_text}'")
                    print(f"  ‚úÖ Actual: '{actual_text}'")
                    print("  ---")

                total_samples += 1

            # Teacher-forced loss on the same batch.
            logits = model(imgs, input_seqs.to(device))
            flat_logits = logits.reshape(-1, logits.size(-1))
            flat_targets = targets.to(device).reshape(-1)
            keep = flat_targets != 2  # PAD token
            if keep.sum() > 0:
                loss_sum += criterion(flat_logits[keep], flat_targets[keep]).item()
                loss_batches += 1

    avg_loss = loss_sum / loss_batches if loss_batches > 0 else 0.0
    accuracy = exact_matches / total_samples if total_samples > 0 else 0.0

    print(f"  Validation Accuracy: {accuracy:.2%} ({exact_matches}/{total_samples} correct)")
    return avg_loss, accuracy

def generate_from_image(model, image, vocab, device, max_len=256):
    """Greedily decode the token string for a single image.

    Decoding starts from '<SOS>' and stops at '<EOS>' or after ``max_len``
    tokens. The model is switched to eval mode.

    Two decoding paths produce the same tokens:
      * Incremental (used when the model exposes ``cnn_encoder``,
        ``embedding``, ``lstm`` and ``output_proj``): the image is encoded
        once and the LSTM advances one step at a time with its state carried
        over. It mirrors a forward pass in which step t receives feature
        column ``t % num_columns`` concatenated with the token embedding
        (input dropout is the identity in eval mode). It must be kept in sync
        if that forward pass changes.
      * Generic fallback: the model is called on the whole prefix at every
        step, which repeats the encoder and all earlier decoder steps.

    Args:
        model: The trained model (see above for the two supported shapes).
        image: Tensor holding exactly one image, [1, C, H, W].
        vocab: Token -> id mapping ('<SOS>' / '<EOS>' default to 0 / 1).
        device: Device on which new token tensors are created.
        max_len: Maximum number of decoding steps.

    Returns:
        str: The generated tokens concatenated without separators; ids that
        are missing from ``vocab`` are rendered as '?'.
    """
    model.eval()

    # Reverse vocabulary for decoding
    idx_to_char = {idx: char for char, idx in vocab.items()}
    sos_token = vocab.get('<SOS>', 0)
    eos_token = vocab.get('<EOS>', 1)

    generated_tokens = []
    can_step = all(
        hasattr(model, part) for part in ('cnn_encoder', 'embedding', 'lstm', 'output_proj')
    )

    with torch.no_grad():
        if can_step:
            # Encode once: [1, C, H', W'] -> one feature vector per column.
            columns = model.cnn_encoder(image).flatten(1, 2).transpose(1, 2)
            num_columns = columns.size(1)
            lstm_state = None
            previous_token = sos_token

            for step in range(max_len):
                token_embedding = model.embedding(
                    torch.tensor([[previous_token]], device=device)
                )
                column = columns[:, step % num_columns].unsqueeze(1)
                lstm_out, lstm_state = model.lstm(
                    torch.cat([column, token_embedding], dim=-1), lstm_state
                )
                next_token = torch.argmax(model.output_proj(lstm_out[0, -1, :])).item()

                if next_token == eos_token:
                    break

                generated_tokens.append(next_token)
                previous_token = next_token
        else:
            current_tokens = torch.tensor([[sos_token]], device=device)

            for step in range(max_len):
                logits = model(image, current_tokens)
                # Greedy decoding - take the most likely token at the last position
                next_token = torch.argmax(logits[0, -1, :]).item()

                if next_token == eos_token:
                    break

                generated_tokens.append(next_token)
                current_tokens = torch.cat([
                    current_tokens,
                    torch.tensor([[next_token]], device=device)
                ], dim=1)

    # Convert tokens to string
    return ''.join(idx_to_char.get(tok, '?') for tok in generated_tokens)

## 7. Main Training Loop

In [None]:
def main():
    """Train Im2Latex on a subset of the tokenized dataset and save the results.

    Note: this definition replaces the data-preparation ``main()`` defined
    earlier in the notebook.

    Workflow: load the prepared JSON (splits + vocabulary), build preloaded
    datasets from the first 20000 training and 2000 validation samples,
    train for 30 epochs with Adam and ReduceLROnPlateau, validate after each
    epoch, keep the weights with the lowest validation loss, and write a full
    checkpoint every 5 epochs. Everything is stored under /content/models/.
    """
    # Training parameters
    json_path = "/content/datasets/im2latex_prepared_tokens.json"
    models_dir = "/content/models"
    batch_size = 32  # Reduced for stability with preloading
    num_epochs = 30
    learning_rate = 1e-4
    max_len = 256
    max_train_samples = 20000  # only the first N prepared samples are used
    max_val_samples = 2000
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    # Load data
    print(f"Loading tokenized data from {json_path}...")
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    train_data = data.get('train', [])
    val_data = data.get('val', [])
    vocab = data.get('vocab', {})

    print(f"Train samples: {len(train_data)}")
    print(f"Val samples: {len(val_data)}")
    print(f"Vocabulary size: {len(vocab)}")

    vocab_size = len(vocab)

    # Save the vocabulary once, up front: it does not change during training,
    # and it is then available even if the run is interrupted.
    os.makedirs(models_dir, exist_ok=True)
    vocab_path = f"{models_dir}/vocab.json"
    with open(vocab_path, 'w', encoding='utf-8') as f:
        json.dump(vocab, f, ensure_ascii=False, indent=2)

    # Create datasets WITH PRELOADING
    print("Creating datasets with preloading...")
    train_data_cropped = train_data[:max_train_samples]
    val_data_cropped = val_data[:max_val_samples]
    train_dataset = Im2LatexDataset(
        train_data_cropped,
        vocab,
        max_len=max_len,
        cache_images=True,
        preload_images=True  # Enable preloading for training
    )
    val_dataset = Im2LatexDataset(
        val_data_cropped,
        vocab,
        max_len=max_len,
        cache_images=True,
        preload_images=True  # Enable preloading for validation
    ) if val_data_cropped else None

    # DataLoader settings: worker processes and pinned memory only with CUDA
    num_workers = 4 if torch.cuda.is_available() else 0
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available()
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available()
    ) if val_dataset else None

    # Create model
    print("Creating model...")
    model = Im2Latex(vocab_size=vocab_size).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Optimizer and loss (index 2 = PAD is excluded from the loss)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=2)
    # Halve the learning rate after 2 epochs without improvement of the monitored loss.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=2
    )

    # Training
    print("\nStarting training with preloaded data...")
    best_val_loss = float('inf')

    for epoch in range(1, num_epochs + 1):
        # Training
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch)
        print(f"Epoch {epoch}/{num_epochs} - Train Loss: {train_loss:.4f}")

        # Validation
        if val_loader:
            val_loss, val_accuracy = validate(model, val_loader, criterion, device, vocab)
            print(f"Epoch {epoch}/{num_epochs} - Val Loss: {val_loss:.4f} - Val Acc: {val_accuracy:.2%}")
            scheduler.step(val_loss)

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = f"{models_dir}/im2latex_best.pth"
                torch.save(model.state_dict(), best_model_path)
                print(f"Saved best model to {best_model_path}")
        else:
            # Without validation data the scheduler follows the training loss.
            scheduler.step(train_loss)

        # Save checkpoint every 5 epochs
        if epoch % 5 == 0:
            checkpoint_path = f"{models_dir}/im2latex_epoch{epoch}.pth"
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
                'vocab_size': vocab_size
            }
            torch.save(checkpoint, checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")

    print("\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Models saved to /content/models/")

## 8. Run Training

In [None]:
# Run the training entry point (the most recently defined main(), i.e. the training loop above).
main()

Using device: cuda
CUDA device: Tesla T4
CUDA memory: 14.74 GB
Loading tokenized data from /content/datasets/im2latex_prepared_tokens.json...
Train samples: 83884
Val samples: 9320
Vocabulary size: 549
Creating datasets with preloading...
üîÑ Preloading 20000 images...


Preloading images:   5%|‚ñå         | 1002/20000 [01:00<16:56, 18.69it/s]

  Loaded 1000/20000 images...


Preloading images:  10%|‚ñà         | 2001/20000 [01:58<17:05, 17.54it/s]

  Loaded 2000/20000 images...


Preloading images:  15%|‚ñà‚ñå        | 3003/20000 [02:57<14:57, 18.93it/s]

  Loaded 3000/20000 images...


Preloading images:  20%|‚ñà‚ñà        | 4002/20000 [03:55<15:05, 17.68it/s]

  Loaded 4000/20000 images...


Preloading images:  25%|‚ñà‚ñà‚ñå       | 5000/20000 [04:53<15:00, 16.66it/s]

  Loaded 5000/20000 images...


Preloading images:  30%|‚ñà‚ñà‚ñà       | 6000/20000 [05:49<11:18, 20.64it/s]

  Loaded 6000/20000 images...


Preloading images:  35%|‚ñà‚ñà‚ñà‚ñå      | 7003/20000 [06:47<12:31, 17.30it/s]

  Loaded 7000/20000 images...


Preloading images:  40%|‚ñà‚ñà‚ñà‚ñà      | 8002/20000 [07:44<11:22, 17.59it/s]

  Loaded 8000/20000 images...


Preloading images:  45%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 9000/20000 [08:41<15:01, 12.20it/s]

  Loaded 9000/20000 images...


Preloading images:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 10001/20000 [09:38<09:06, 18.31it/s]

  Loaded 10000/20000 images...


Preloading images:  55%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå    | 11001/20000 [10:37<12:40, 11.84it/s]

  Loaded 11000/20000 images...


Preloading images:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ    | 11997/20000 [11:33<07:54, 16.87it/s]

  Loaded 12000/20000 images...


Preloading images:  65%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 13003/20000 [12:31<07:02, 16.56it/s]

  Loaded 13000/20000 images...


Preloading images:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 14002/20000 [13:28<04:50, 20.66it/s]

  Loaded 14000/20000 images...


Preloading images:  75%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 15001/20000 [14:25<03:57, 21.09it/s]

  Loaded 15000/20000 images...


Preloading images:  80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 16002/20000 [15:23<03:39, 18.25it/s]

  Loaded 16000/20000 images...


Preloading images:  85%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç | 16998/20000 [16:20<02:45, 18.19it/s]

  Loaded 17000/20000 images...


Preloading images:  90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ | 17998/20000 [17:19<02:02, 16.34it/s]

  Loaded 18000/20000 images...


Preloading images:  95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 19002/20000 [18:16<00:50, 19.77it/s]

  Loaded 19000/20000 images...


Preloading images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20000/20000 [19:14<00:00, 17.33it/s]


  Loaded 20000/20000 images...
‚úÖ Preloaded 20000 images into memory
üîÑ Preloading 2000 images...


Preloading images:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 1002/2000 [00:56<00:51, 19.53it/s]

  Loaded 1000/2000 images...


Preloading images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2000/2000 [01:54<00:00, 17.44it/s]


  Loaded 2000/2000 images...
‚úÖ Preloaded 2000 images into memory
Creating model...
Model parameters: 12,204,325

Starting training with preloaded data...


Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 625/625 [03:03<00:00,  3.40it/s, loss=4.119, avg=4.223]


Epoch 1/30 - Train Loss: 4.2234


Validation:   0%|                                                            | 0/63 [00:00<?, ?it/s]

  ‚ùå Pred: '\label{{{{{{{}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}'
  ‚úÖ Actual: '\int_{-\epsilon}^\inftydl\:{\rme}^{-l\zeta}\int_{-\epsilon}^\inftydl'{\rme}^{-l'\zeta}ll'{l'-l\overl+l'}\{3\,\delta''(l)-{3\over4}t\,\delta(l)\}=0.\label{eq21}'
  ---
  ‚ùå Pred: '\label{{{{{{{}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}'
  ‚úÖ Actual: '\label{hR=hR+hR+}\hat{R}=d\hat{C}-\hat{C}\wedge\hat{H}_3=\hat{R}_2\oplus\hat{R}_4\oplus\hat{R}_6\oplus\hat{R}_8\oplus\hat{R}_{10}\;,'
  ---
  ‚ùå Pred: '\label{{{{{{{}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}

Validation:  16%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä                                         | 10/63 [11:32<1:01:06, 69.19s/it]

## 9. Download Trained Models (Colab)

In [None]:
# Hand the trained models to the user: as a browser download on Colab,
# otherwise just report where they are stored.
if not IN_COLAB:
    print("Models saved in /content/models/")
else:
    import shutil
    from google.colab import files

    # Bundle the whole models folder into im2latex_models.zip in the working directory.
    shutil.make_archive('im2latex_models', 'zip', '/content/models')

    # Trigger the download of the archive.
    files.download('im2latex_models.zip')
    print("Models downloaded!")