In [14]:
import cv2
import numpy as np
import xmltodict
import os
import glob
import re
import json

In [5]:
def extract_label_from_inkml(inkml_path):
    with open(inkml_path, 'r', encoding='utf-8') as f:
        doc = xmltodict.parse(f.read())

    annotations = doc['ink'].get('annotation', [])

    if isinstance(annotations, dict):  # Handle single annotation case
        annotations = [annotations]

    label = "UNKNOWN"
    for annotation in annotations:
        if annotation.get('@type', '') == 'truth':
            label = annotation['#text'].strip()  # Remove leading/trailing spaces
            label = re.sub(r'\s+', ' ', label)  # Replace multiple spaces with a single space
            break

    return label


In [66]:
def preprocess_image(image_path, target_size=(256, 256)):
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    # Step 1: Apply Gaussian Blur to reduce noise while preserving edges
    img = cv2.GaussianBlur(img, (5, 5), 0)

    # Step 2: Apply Otsu's Thresholding (better for noise removal)
    _, img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Step 3: Morphological Operations to remove noise and fill gaps
    kernel = np.ones((3, 3), np.uint8)
    img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel, iterations=1)  # Remove small noise
    img = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel, iterations=1)  # Fill small gaps

    # Step 4: Find bounding box of equation (ignore extra white space)
    coords = cv2.findNonZero(255 - img)  # Invert colors and find non-zero pixels
    if coords is None:
        return np.ones(target_size, dtype=np.uint8) * 255  # Return blank image

    x, y, w, h = cv2.boundingRect(coords)

    # Step 5: Expand bounding box slightly to avoid tight cropping
    pad = 15
    x = max(0, x - pad)
    y = max(0, y - pad)
    w = min(img.shape[1] - x, w + 2 * pad)
    h = min(img.shape[0] - y, h + 2 * pad)

    img_cropped = img[y:y+h, x:x+w]

    # Step 6: Resize while keeping aspect ratio
    h, w = img_cropped.shape
    target_w, target_h = target_size
    scale = min(target_w / w, target_h / h)
    new_w, new_h = int(w * scale), int(h * scale)

    img_resized = cv2.resize(img_cropped, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # Step 7: Create a white canvas and center the image
    canvas = np.ones((target_h, target_w), dtype=np.uint8) * 255
    pad_x = (target_w - new_w) // 2
    pad_y = (target_h - new_h) // 2
    canvas[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = img_resized

    return canvas



In [67]:
image_input_folder = "dataset/crohme2023/IMG/train/CROHME2019"
image_output_folder = "dataset/crohme2023/IMG/train/CROHME2019_preprocessed"
labels_input_folder = "dataset/crohme2023/INKML/train/CROHME2019" 
labels_output_folder = "dataset/crohme2023/labels/labels_2019"  

In [69]:
os.makedirs(image_output_folder, exist_ok=True)
os.makedirs(labels_output_folder, exist_ok=True)
count = 0
for img in os.listdir(image_input_folder):
    preprocessed_image = preprocess_image(os.path.join(image_input_folder, img))
    cv2.imwrite(os.path.join(image_output_folder, img), preprocessed_image)
    inkml_path = os.path.join(labels_input_folder, img.replace(".png", ".inkml"))
    label = extract_label_from_inkml(inkml_path)
    with open(os.path.join(labels_output_folder, img.replace(".png", ".txt")), "w", encoding="utf-8") as f:
        f.write(label)
    count += 1
    if count % 100 == 0:
        print(f"Processed {count} images")


Processed 100 images
Processed 200 images
Processed 300 images
Processed 400 images
Processed 500 images
Processed 600 images
Processed 700 images
Processed 800 images
Processed 900 images
Processed 1000 images
Processed 1100 images
Processed 1200 images
Processed 1300 images
Processed 1400 images
Processed 1500 images
Processed 1600 images
Processed 1700 images
Processed 1800 images
Processed 1900 images
Processed 2000 images
Processed 2100 images
Processed 2200 images
Processed 2300 images
Processed 2400 images
Processed 2500 images
Processed 2600 images
Processed 2700 images
Processed 2800 images
Processed 2900 images
Processed 3000 images
Processed 3100 images
Processed 3200 images
Processed 3300 images
Processed 3400 images
Processed 3500 images
Processed 3600 images
Processed 3700 images
Processed 3800 images
Processed 3900 images
Processed 4000 images
Processed 4100 images
Processed 4200 images
Processed 4300 images
Processed 4400 images
Processed 4500 images
Processed 4600 imag

In [79]:
def tokeniseEquation(equation):
    """
    Tokenize a mathematical equation into individual tokens.
    
    Args:
        equation (str): The equation string to tokenize.
    
    Returns:
        list: A list of tokenized elements (e.g., operators, numbers, variables, LaTeX commands).
    """
    # Step 1: Normalize the equation
    equation = equation.strip().lower()
    
    # Step 2: Add spaces around digits to separate them
    equation = re.sub(r"(\d)", r" \1 ", equation)
    
    # Step 3: Tokenize the equation
    # Updated pattern to include <, >, <=, >=, and other comparison operators
    tokenized_equation = re.findall(
        r"\\[a-zA-Z]+|[a-z]|\d+|[\^_={}\[\]\(\)+\-*/×<>≤≥≠]", 
        equation
    )
    
    return tokenized_equation

# Example usage
if __name__ == "__main__":
    # Test cases
    test_equations = [
        "x + 2 < 3",
        "y^2 >= 4",
        "a <= b",
        "5 > 2",
        "x = 3 ≠ 4",
        "\\frac{1}{2} + x",
        "\sqrt{2}\sqrt{2} = 2",
        "$0 < x < 1$",
        "$T \in E$"
    ]
    
    for eq in test_equations:
        tokens = tokeniseEquation(eq)
        print(f"Equation: {eq}")
        print(f"Tokens: {tokens}\n")

Equation: x + 2 < 3
Tokens: ['x', '+', '2', '<', '3']

Equation: y^2 >= 4
Tokens: ['y', '^', '2', '>', '=', '4']

Equation: a <= b
Tokens: ['a', '<', '=', 'b']

Equation: 5 > 2
Tokens: ['5', '>', '2']

Equation: x = 3 ≠ 4
Tokens: ['x', '=', '3', '≠', '4']

Equation: \frac{1}{2} + x
Tokens: ['\\frac', '{', '1', '}', '{', '2', '}', '+', 'x']

Equation: \sqrt{2}\sqrt{2} = 2
Tokens: ['\\sqrt', '{', '2', '}', '\\sqrt', '{', '2', '}', '=', '2']

Equation: $0 < x < 1$
Tokens: ['0', '<', 'x', '<', '1']

Equation: $T \in E$
Tokens: ['t', '\\in', 'e']



  "\sqrt{2}\sqrt{2} = 2",
  "$T \in E$"


In [81]:
import glob
import json
import os
from typing import Dict, List, Set
import logging

# Configure logging for better debugging and monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def build_vocabulary(label_dir: str, output_vocab_file: str, tokenise_func) -> Dict[str, int]:
    """
    Build a vocabulary from equation labels and save it to a JSON file.

    Args:
        label_dir (str): Directory containing label files (e.g., 'dataset/crohme2023/labels/labels_2019/*.txt').
        output_vocab_file (str): Path to save the vocabulary JSON file.
        tokenise_func (callable): Function to tokenize equations (e.g., tokeniseEquation).

    Returns:
        Dict[str, int]: A dictionary mapping tokens to their indices.
    """
    try:
        # Step 1: Collect all label file paths
        label_paths = sorted(glob.glob(os.path.join(label_dir, "*.txt")))
        if not label_paths:
            raise ValueError(f"No label files found in directory: {label_dir}")
        logger.info(f"Found {len(label_paths)} label files in {label_dir}")

        # Step 2: Define a set to store unique tokens
        tokens: Set[str] = set()

        # Step 3: Read and tokenize each label file
        for label_path in label_paths:
            try:
                with open(label_path, "r", encoding="utf-8") as f:
                    equation = f.read().strip()
                    if not equation:
                        logger.warning(f"Empty equation in file: {label_path}")
                        continue
                    tokenized_equation = tokenise_func(equation)
                    if not tokenized_equation:
                        logger.warning(f"No tokens produced for equation in file: {label_path}")
                        continue
                    tokens.update(tokenized_equation)
            except UnicodeDecodeError as e:
                logger.error(f"Encoding error in file {label_path}: {e}")
                continue
            except Exception as e:
                logger.error(f"Error processing file {label_path}: {e}")
                continue

        if not tokens:
            raise ValueError("No valid tokens found in any label files")

        logger.info(f"Extracted {len(tokens)} unique tokens")

        # Step 4: Create token-to-index mapping
        # Sort tokens for deterministic ordering
        sorted_tokens = sorted(tokens)
        
        # Define special tokens
        special_tokens = ["<PAD>", "<UNK>", "<SOS>", "<EOS>"]
        token_to_index = {token: idx for idx, token in enumerate(special_tokens)}
        
        # Add regular tokens starting after special tokens
        for idx, token in enumerate(sorted_tokens, len(special_tokens)):
            token_to_index[token] = idx

        # Step 5: Create index-to-token mapping for decoding
        index_to_token = {idx: token for token, idx in token_to_index.items()}

        # Step 6: Save vocabulary
        vocab = {
            "token_to_index": token_to_index,
            "index_to_token": index_to_token
        }
        with open(output_vocab_file, "w", encoding="utf-8") as f:
            json.dump(vocab, f, indent=4, ensure_ascii=False)
        
        logger.info(f"Saved vocabulary to {output_vocab_file}")
        logger.info(f"Vocabulary size: {len(token_to_index)} (including {len(special_tokens)} special tokens)")

        return token_to_index

    except Exception as e:
        logger.error(f"Failed to build vocabulary: {e}")
        raise

# Example usage
if __name__ == "__main__":
    # Define paths
    label_dir = "dataset/crohme2023/labels/labels_2019"
    output_vocab_file = "vocab.json"

    # Build vocabulary
    token_to_index = build_vocabulary(label_dir, output_vocab_file, tokeniseEquation)

2025-03-08 21:25:42,339 - INFO - Found 10979 label files in dataset/crohme2023/labels/labels_2019
2025-03-08 21:25:43,796 - INFO - Extracted 105 unique tokens
2025-03-08 21:25:43,798 - INFO - Saved vocabulary to vocab.json
2025-03-08 21:25:43,798 - INFO - Vocabulary size: 109 (including 4 special tokens)


In [84]:
def load_vocabulary(vocab_file: str) -> tuple[Dict[str, int], Dict[int, str]]:
    """
    Load the vocabulary from a JSON file.

    Args:
        vocab_file (str): Path to the vocabulary JSON file.

    Returns:
        tuple: (token_to_index, index_to_token) mappings.
    """
    with open(vocab_file, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    return vocab["token_to_index"], vocab["index_to_token"]

token_to_index, index_to_token = load_vocabulary("vocab.json")

In [85]:
import glob
import json
import os
import re
from typing import Dict, List, Optional
import logging

# Configure logging for better debugging and monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def tokenize_latex(equation: str, token_to_index: Dict[str, int], verbose: bool = False) -> List[int]:
    """
    Tokenize a LaTeX equation and convert tokens to indices.

    Args:
        equation (str): The LaTeX equation to tokenize.
        token_to_index (Dict[str, int]): Mapping of tokens to their indices.
        verbose (bool): Whether to log the tokenized equation for debugging.

    Returns:
        List[int]: List of token indices.
    """
    equation = equation.strip().lower()  # Convert to lowercase for consistency

    # Extract LaTeX commands, letters, digits, operators, and comparison symbols
    tokens = re.findall(
        r"\\[a-zA-Z]+|[a-z]|\d+|[\^_={}\[\]\(\)+\-*/×<>≤≥≠]", 
        equation
    )

    if verbose:
        logger.info(f"Tokenized equation '{equation}': {tokens}")

    # Convert tokens to indices, using <UNK> for unknown tokens
    indices = [token_to_index.get(token, token_to_index["<UNK>"]) for token in tokens]
    return indices

def pad_sequence(sequence: List[int], max_length: int, token_to_index: Dict[str, int]) -> List[int]:
    """
    Pad a sequence to a fixed length, adding <SOS> and <EOS> tokens.

    Args:
        sequence (List[int]): List of token indices.
        max_length (int): Desired length of the padded sequence.
        token_to_index (Dict[str, int]): Mapping of tokens to their indices.

    Returns:
        List[int]: Padded sequence.
    """
    if len(sequence) > max_length - 2:
        logger.warning(f"Sequence truncated from {len(sequence)} to {max_length - 2} tokens")
        sequence = sequence[:max_length - 2]
    
    padded_sequence = [token_to_index["<SOS>"]] + sequence + [token_to_index["<EOS>"]]
    padded_sequence += [token_to_index["<PAD>"]] * (max_length - len(padded_sequence))
    return padded_sequence

def process_labels(labels_input_folder: str, tokenized_labels_output_folder: str, 
                   token_to_index: Dict[str, int], index_to_token: Dict[int, str], 
                   verbose: bool = False) -> None:
    """
    Process label files by tokenizing, padding, and saving the results.

    Args:
        labels_input_folder (str): Directory containing label files.
        tokenized_labels_output_folder (str): Directory to save tokenized and padded labels.
        token_to_index (Dict[str, int]): Mapping of tokens to their indices.
        index_to_token (Dict[int, str]): Mapping of indices to their tokens.
        verbose (bool): Whether to log tokenized equations for debugging.
    """
    try:
        # Create output directory if it doesn't exist
        os.makedirs(tokenized_labels_output_folder, exist_ok=True)
        logger.info(f"Output directory: {tokenized_labels_output_folder}")

        # Collect all label file paths
        label_files = sorted(glob.glob(os.path.join(labels_input_folder, "*.txt")))
        if not label_files:
            raise ValueError(f"No label files found in directory: {labels_input_folder}")
        logger.info(f"Found {len(label_files)} label files")

        # Step 1: Tokenize all labels and determine max length
        tokenized_labels = []
        max_length = 0
        max_label_info = None

        for label_file in label_files:
            try:
                with open(label_file, "r", encoding="utf-8") as f:
                    label_content = f.read().strip()
                
                if not label_content:
                    logger.warning(f"Empty label file: {label_file}")
                    continue
                
                tokenized_label = tokenize_latex(label_content, token_to_index, verbose=verbose)
                if not tokenized_label:
                    logger.warning(f"No tokens produced for label file: {label_file}")
                    continue
                
                tokenized_labels.append((label_file, label_content, tokenized_label))
                
                current_length = len(tokenized_label)
                if current_length > max_length:
                    max_length = current_length
                    max_label_info = (label_file, label_content)
            except UnicodeDecodeError as e:
                logger.error(f"Encoding error in file {label_file}: {e}")
                continue
            except Exception as e:
                logger.error(f"Error processing file {label_file}: {e}")
                continue

        if not tokenized_labels:
            raise ValueError("No valid tokenized labels produced")

        # Account for <SOS> and <EOS> tokens in max_length
        max_length += 2
        logger.info(f"Max sequence length (including <SOS> and <EOS>): {max_length}")
        if max_label_info:
            logger.info(f"Longest label: '{max_label_info[1]}' from file {max_label_info[0]}")

        # Step 2: Process and save tokenized labels
        for idx, (label_file, label_content, tokenized_label) in enumerate(tokenized_labels):
            try:
                # Pad the tokenized label
                padded_label = pad_sequence(tokenized_label, max_length, token_to_index)

                # For debugging, decode the padded label back to tokens
                if verbose:
                    decoded_label = [index_to_token[str(idx)] for idx in padded_label]
                    logger.info(f"Padded label for {label_file}: {decoded_label}")

                # Save the tokenized and padded label
                tokenized_label_path = os.path.join(tokenized_labels_output_folder, os.path.basename(label_file))
                with open(tokenized_label_path, "w", encoding="utf-8") as f:
                    f.write(" ".join(map(str, padded_label)))

                logger.info(f"Processed {idx + 1}/{len(tokenized_labels)}: {os.path.basename(label_file)}")
            except Exception as e:
                logger.error(f"Error saving tokenized label for file {label_file}: {e}")
                continue

        logger.info("🎉 Tokenization and padding complete!")

    except Exception as e:
        logger.error(f"Failed to process labels: {e}")
        raise

# Example usage
if __name__ == "__main__":
    # Define paths
    # vocab_file = "vocab.json"
    labels_input_folder = "dataset/crohme2023/labels/labels_2019"
    tokenized_labels_output_folder = "dataset/crohme2023/labels/tokenized_labels_2019"

    # Load vocabulary
    # token_to_index, index_to_token = load_vocabulary(vocab_file)

    # Process labels
    process_labels(labels_input_folder, tokenized_labels_output_folder, token_to_index, index_to_token, verbose=False)

2025-03-08 21:31:43,917 - INFO - Output directory: dataset/crohme2023/labels/tokenized_labels_2019
2025-03-08 21:31:43,954 - INFO - Found 10979 label files
2025-03-08 21:31:45,485 - INFO - Max sequence length (including <SOS> and <EOS>): 198
2025-03-08 21:31:45,486 - INFO - Longest label: '$|x^{\frac{1}{n}} - c^{\frac{1}{n}}| = \frac{|x^{\frac{1}{n}} - c^{\frac{1}{n}}||x^{\frac{n-1}{n}} + x^{\frac{n-2}{n}}c^{\frac{1}{n}} + \cdots + x^{\frac{1}{n}}c^{\frac{n-2}{n}}|}{|x^{\frac{n-1}{n}} + x^{\frac{n-2}{n}}c^{\frac{1}{n}} + \cdots + x^{\frac{1}{n}}c^{\frac{n-2}{n}} + c^{\frac{n-1}{n}}|}$' from file dataset/crohme2023/labels/labels_2019/505_em_51.txt
2025-03-08 21:31:45,487 - INFO - Processed 1/10979: 001-equation000.txt
2025-03-08 21:31:45,488 - INFO - Processed 2/10979: 001-equation001.txt
2025-03-08 21:31:45,488 - INFO - Processed 3/10979: 001-equation0010.txt
2025-03-08 21:31:45,489 - INFO - Processed 4/10979: 001-equation0011.txt
2025-03-08 21:31:45,489 - INFO - Processed 5/10979: 001

In [89]:
import torch
from torch.utils.data import Dataset
import cv2
import os
import json
from glob import glob
import numpy as np
from typing import Dict, List, Tuple, Optional
import logging
from torchvision import transforms

# Configure logging for better debugging and monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class MathEquationDataset(Dataset):
    def __init__(self, image_dir: str, tokenized_label_dir: str, vocab_file: str, 
                 max_seq_len: int, image_size: Tuple[int, int] = (256, 256), 
                 transform: Optional[transforms.Compose] = None) -> None:
        """
        PyTorch Dataset for loading handwritten equation images and their tokenized labels.

        Args:
            image_dir (str): Directory containing preprocessed image files.
            tokenized_label_dir (str): Directory containing tokenized label files.
            vocab_file (str): Path to the vocabulary JSON file.
            max_seq_len (int): Maximum sequence length for tokenized labels.
            image_size (Tuple[int, int]): Target size for preprocessed images.
            transform (transforms.Compose, optional): Additional image transformations.
        """
        self.image_dir = image_dir
        self.tokenized_label_dir = tokenized_label_dir
        self.max_seq_len = max_seq_len
        self.image_size = image_size
        self.transform = transform

        # Load vocabulary
        try:
            with open(vocab_file, "r", encoding="utf-8") as f:
                vocab = json.load(f)
            self.token_to_index = vocab.get("token_to_index", vocab)  # Handle both old and new vocab formats
            self.index_to_token = vocab.get("index_to_token", {str(v): k for k, v in self.token_to_index.items()})
            
            # Validate special tokens
            required_special_tokens = ["<PAD>", "<UNK>", "<SOS>", "<EOS>"]
            missing_special_tokens = [token for token in required_special_tokens if token not in self.token_to_index]
            if missing_special_tokens:
                raise ValueError(f"Vocabulary is missing required special tokens: {missing_special_tokens}")
        except Exception as e:
            logger.error(f"Error loading vocabulary: {e}")
            raise

        # Get list of image and tokenized label files
        self.image_paths = sorted(glob(os.path.join(image_dir, "*.png")))
        self.tokenized_label_paths = sorted(glob(os.path.join(tokenized_label_dir, "*.txt")))

        if not self.image_paths or not self.tokenized_label_paths:
            raise ValueError(f"No images or labels found in directories: {image_dir}, {tokenized_label_dir}")

        # Validate file pairing
        self._validate_file_pairing()

    def _validate_file_pairing(self) -> None:
        """
        Validate that image and label files are paired correctly.
        """
        if len(self.image_paths) != len(self.tokenized_label_paths):
            raise ValueError(f"Mismatch between number of images ({len(self.image_paths)}) "
                             f"and labels ({len(self.tokenized_label_paths)})")

        for img_path, lbl_path in zip(self.image_paths, self.tokenized_label_paths):
            img_basename = os.path.splitext(os.path.basename(img_path))[0]
            lbl_basename = os.path.splitext(os.path.basename(lbl_path))[0]
            if img_basename != lbl_basename:
                raise ValueError(f"Mismatch between image and label file names: {img_path} vs {lbl_path}")

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load and preprocess an image and its tokenized label.

        Args:
            idx (int): Index of the sample to load.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Preprocessed image tensor and tokenized label tensor.
        """
        try:
            # Load image
            image_path = self.image_paths[idx]
            image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if image is None:
                raise ValueError(f"Failed to load image: {image_path}")


            # Convert to tensor and add channel dimension
            image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)  # Shape: [1, H, W]

            # Apply additional transforms if provided
            if self.transform:
                image = self.transform(image)

            # Load tokenized label
            tokenized_label_path = self.tokenized_label_paths[idx]
            with open(tokenized_label_path, "r", encoding="utf-8") as f:
                tokenized_label = list(map(int, f.read().strip().split()))

            # Validate sequence length
            if len(tokenized_label) > self.max_seq_len:
                logger.warning(f"Truncating label sequence from {len(tokenized_label)} to {self.max_seq_len} tokens")
                tokenized_label = tokenized_label[:self.max_seq_len]

            # Convert to tensor
            tokenized_label = torch.tensor(tokenized_label, dtype=torch.long)

            return image, tokenized_label

        except Exception as e:
            logger.error(f"Error loading sample {idx} (image: {self.image_paths[idx]}, "
                         f"label: {self.tokenized_label_paths[idx]}): {e}")
            raise

# Example usage
if __name__ == "__main__":
    # Define image transforms (optional)
    image_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
    ])

    dataset = MathEquationDataset(
        image_dir="dataset/crohme2023/IMG/train/CROHME2019_preprocessed",
        tokenized_label_dir="dataset/crohme2023/labels/tokenized_labels_2019",
        vocab_file="vocab.json",
        max_seq_len=198,  # Use the actual max sequence length
        image_size=(256, 256),
        transform=image_transforms
    )

    # Test loading a sample
    image, tokenized_label = dataset[0]
    print(f"Image shape: {image.shape}, Tokenized label shape: {tokenized_label.shape}")

    # Decode the tokenized label for debugging
    decoded_label = [dataset.index_to_token[str(idx.item())] for idx in tokenized_label 
                     if idx.item() not in {dataset.token_to_index["<PAD>"], 
                                          dataset.token_to_index["<SOS>"], 
                                          dataset.token_to_index["<EOS>"]}]
    print(f"Decoded label: {decoded_label}")

Image shape: torch.Size([1, 256, 256]), Tokenized label shape: torch.Size([198])
Decoded label: ['y', '=', 'a', 'x', '+', 'a', '^', '2']


In [90]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import math
from typing import Dict, List, Optional, Tuple
import logging

# Configure logging for better debugging and monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def load_vocabulary(vocab_file: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """
    Load the vocabulary from a JSON file.

    Args:
        vocab_file (str): Path to the vocabulary JSON file.

    Returns:
        Tuple[Dict[str, int], Dict[int, str]]: (token_to_index, index_to_token) mappings.

    Raises:
        FileNotFoundError: If the vocabulary file does not exist.
        ValueError: If the vocabulary file is malformed or missing required special tokens.
    """
    try:
        with open(vocab_file, "r", encoding="utf-8") as f:
            vocab = json.load(f)
        token_to_index = vocab.get("token_to_index", vocab)  # Handle both old and new vocab formats
        index_to_token = vocab.get("index_to_token", {str(v): k for k, v in token_to_index.items()})
        
        # Validate special tokens
        required_special_tokens = ["<PAD>", "<UNK>", "<SOS>", "<EOS>"]
        missing_special_tokens = [token for token in required_special_tokens if token not in token_to_index]
        if missing_special_tokens:
            raise ValueError(f"Vocabulary is missing required special tokens: {missing_special_tokens}")
        
        return token_to_index, index_to_token
    except FileNotFoundError:
        logger.error(f"Vocabulary file not found: {vocab_file}")
        raise
    except json.JSONDecodeError:
        logger.error(f"Vocabulary file is malformed: {vocab_file}")
        raise
    except Exception as e:
        logger.error(f"Error loading vocabulary: {e}")
        raise

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim: int, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape: [1, max_len, embed_dim]
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encoding to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, embed_dim].

        Returns:
            torch.Tensor: Tensor with positional encoding added.
        """
        x = x + self.pe[:, :x.size(1), :]
        return x

class CNNEncoder(nn.Module):
    def __init__(self, image_size: Tuple[int, int] = (256, 256), embed_dim: int = 256):
        """
        CNN encoder for extracting features from handwritten equation images.

        Args:
            image_size (Tuple[int, int]): Expected input image size (H, W).
            embed_dim (int): Dimension of the output embedding.
        """
        super(CNNEncoder, self).__init__()
        self.image_size = image_size
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        
        # Calculate the size of the flattened feature map
        h, w = image_size
        h, w = h // 4, w // 4  # Two max pooling layers
        self.flatten_dim = 64 * h * w
        self.fc = nn.Linear(self.flatten_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the CNN encoder.

        Args:
            x (torch.Tensor): Input images of shape [B, 1, H, W].

        Returns:
            torch.Tensor: Encoded features of shape [B, embed_dim].
        """
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)  # [B, 64, H/4, W/4]
        x = x.view(x.size(0), -1)  # Flatten: [B, 64*H/4*W/4]
        x = self.fc(x)  # [B, embed_dim]
        return x

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int = 256, num_heads: int = 4, 
                 num_layers: int = 2, max_seq_len: int = 200, dropout: float = 0.1):
        """
        Transformer decoder for generating tokenized equation sequences.

        Args:
            vocab_size (int): Size of the vocabulary.
            embed_dim (int): Dimension of the embeddings.
            num_heads (int): Number of attention heads.
            num_layers (int): Number of decoder layers.
            max_seq_len (int): Maximum sequence length for positional encoding.
            dropout (float): Dropout rate.
        """
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim, max_seq_len)
        self.decoder_layer = nn.TransformerDecoderLayer(embed_dim, num_heads, 
                                                        dim_feedforward=embed_dim * 4, 
                                                        dropout=dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        self.embed_dim = embed_dim

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor, 
                tgt_mask: Optional[torch.Tensor] = None, 
                tgt_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass of the Transformer decoder.

        Args:
            tgt (torch.Tensor): Target sequence of shape [B, T].
            memory (torch.Tensor): Encoded image features of shape [B, embed_dim].
            tgt_mask (torch.Tensor, optional): Mask for target sequence to prevent attending to future tokens.
            tgt_key_padding_mask (torch.Tensor, optional): Mask for padding tokens in the target sequence.

        Returns:
            torch.Tensor: Output logits of shape [B, T, vocab_size].
        """
        tgt_embed = self.embedding(tgt) * math.sqrt(self.embed_dim)
        tgt_embed = self.pos_encoder(tgt_embed)
        memory = memory.unsqueeze(1)  # [B, 1, embed_dim] for cross-attention
        output = self.decoder(tgt_embed, memory, tgt_mask=tgt_mask, 
                              tgt_key_padding_mask=tgt_key_padding_mask)
        output = self.fc_out(output)
        return output

class MathEquationModel(nn.Module):
    def __init__(self, vocab_size: int, image_size: Tuple[int, int] = (256, 256), 
                 embed_dim: int = 256, num_heads: int = 4, num_layers: int = 2, 
                 max_seq_len: int = 200, dropout: float = 0.1):
        """
        End-to-end model for handwritten equation recognition.

        Args:
            vocab_size (int): Size of the vocabulary.
            image_size (Tuple[int, int]): Expected input image size (H, W).
            embed_dim (int): Dimension of the embeddings.
            num_heads (int): Number of attention heads in the decoder.
            num_layers (int): Number of decoder layers.
            max_seq_len (int): Maximum sequence length for positional encoding.
            dropout (float): Dropout rate.
        """
        super(MathEquationModel, self).__init__()
        self.encoder = CNNEncoder(image_size, embed_dim)
        self.decoder = TransformerDecoder(vocab_size, embed_dim, num_heads, num_layers, 
                                          max_seq_len, dropout)
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

    def generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        """
        Generate a square subsequent mask for the target sequence.

        Args:
            sz (int): Size of the sequence.

        Returns:
            torch.Tensor: Mask of shape [sz, sz] with -inf on the upper triangle.
        """
        mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
        return mask

    def generate_padding_mask(self, tgt: torch.Tensor, pad_idx: int) -> torch.Tensor:
        """
        Generate a padding mask for the target sequence.

        Args:
            tgt (torch.Tensor): Target sequence of shape [B, T].
            pad_idx (int): Index of the padding token.

        Returns:
            torch.Tensor: Padding mask of shape [B, T] with True for padding tokens.
        """
        return tgt == pad_idx

    def forward(self, image: torch.Tensor, tgt: torch.Tensor, pad_idx: int) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            image (torch.Tensor): Input images of shape [B, 1, H, W].
            tgt (torch.Tensor): Target sequence of shape [B, T].
            pad_idx (int): Index of the padding token.

        Returns:
            torch.Tensor: Output logits of shape [B, T, vocab_size].
        """
        memory = self.encoder(image)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(image.device)
        tgt_key_padding_mask = self.generate_padding_mask(tgt, pad_idx).to(image.device)
        output = self.decoder(tgt, memory, tgt_mask, tgt_key_padding_mask)
        return output

    def predict(self, image: torch.Tensor, token_to_index: Dict[str, int], 
                index_to_token: Dict[int, str], max_len: Optional[int] = None, 
                device: str = "cpu") -> List[str]:
        """
        Predict a tokenized equation sequence from an image.

        Args:
            image (torch.Tensor): Input image of shape [1, 1, H, W].
            token_to_index (Dict[str, int]): Mapping of tokens to their indices.
            index_to_token (Dict[int, str]): Mapping of indices to their tokens.
            max_len (int, optional): Maximum length of the predicted sequence.
            device (str): Device to run inference on.

        Returns:
            List[str]: Predicted token sequence.
        """
        self.eval()
        max_len = max_len or self.max_seq_len
        sos_idx = token_to_index["<SOS>"]
        eos_idx = token_to_index["<EOS>"]
        pad_idx = token_to_index["<PAD>"]

        with torch.no_grad():
            memory = self.encoder(image)
            tgt = torch.tensor([[sos_idx]], dtype=torch.long, device=device)
            outputs = []
            for _ in range(max_len):
                tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(device)
                output = self.decoder(tgt, memory, tgt_mask)
                next_token = output[:, -1, :].argmax(dim=-1)
                token_idx = next_token.item()
                if token_idx == eos_idx:
                    break
                if token_idx != pad_idx:  # Ignore padding tokens in output
                    outputs.append(index_to_token[str(token_idx)])
                tgt = torch.cat([tgt, next_token.unsqueeze(0)], dim=1)
            return outputs

    def save(self, path: str) -> None:
        """
        Save the model state to a file.

        Args:
            path (str): Path to save the model state.
        """
        torch.save(self.state_dict(), path)
        logger.info(f"Model saved to {path}")

    def load(self, path: str, device: str = "cpu") -> None:
        """
        Load the model state from a file.

        Args:
            path (str): Path to the model state file.
            device (str): Device to load the model on.
        """
        self.load_state_dict(torch.load(path, map_location=device))
        logger.info(f"Model loaded from {path}")

# Example usage
if __name__ == "__main__":
    # Load vocabulary
    vocab_file = "vocab.json"
    token_to_index, index_to_token = load_vocabulary(vocab_file)
    vocab_size = len(token_to_index)

    # Initialize model
    model = MathEquationModel(
        vocab_size=vocab_size,
        image_size=(256, 256),
        embed_dim=256,
        num_heads=4,
        num_layers=2,
        max_seq_len=198,  # Use the actual max sequence length
        dropout=0.1
    )

    # Print model parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params}")

    # Test forward pass with dummy data
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    batch_size = 2
    dummy_image = torch.randn(batch_size, 1, 256, 256, device=device)
    dummy_tgt = torch.randint(0, vocab_size, (batch_size, 10), device=device)
    pad_idx = token_to_index["<PAD>"]
    output = model(dummy_image, dummy_tgt, pad_idx)
    print(f"Output shape: {output.shape}")

    # Test prediction with dummy image
    dummy_image = torch.randn(1, 1, 256, 256, device=device)
    predicted_sequence = model.predict(dummy_image, token_to_index, index_to_token, device=device)
    print(f"Predicted sequence: {predicted_sequence}")

Model parameters: 70344173
Output shape: torch.Size([2, 10, 109])
Predicted sequence: ['3', '<', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '\\gtm', '\\exists', '_', '9', '1', '\\neq', 'l', '\\beta', 'u', '\\alpha', '\\sqrt', '\\gt',

In [96]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import logging
import os
from typing import Dict, Tuple, Optional
# from dataset import MathEquationDataset  # Assuming the dataset code is in dataset.py
# from model import MathEquationModel, load_vocabulary  # Assuming the model code is in model.py

# Configure logging for better debugging and monitoring
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def train_model(model: nn.Module, train_dataset: MathEquationDataset, 
                val_dataset: Optional[MathEquationDataset] = None, num_epochs: int = 10, 
                batch_size: int = 16, device: str = "cuda", learning_rate: float = 0.0001, 
                model_save_path: str = "math_equation_model.pth", patience: int = 5) -> Dict[str, float]:
    """
    Train the model on the dataset.

    Args:
        model (nn.Module): The model to train.
        train_dataset (MathEquationDataset): Training dataset.
        val_dataset (MathEquationDataset, optional): Validation dataset.
        num_epochs (int): Number of epochs to train.
        batch_size (int): Batch size for training.
        device (str): Device to train on ('cuda', 'cpu', or 'mps').
        learning_rate (float): Initial learning rate.
        model_save_path (str): Path to save the best model.
        patience (int): Patience for early stopping.

    Returns:
        Dict[str, float]: Dictionary containing the best validation metrics.
    """
    try:
        # Move model to device
        device = torch.device(device if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        logger.info(f"Training on device: {device}")

        # Create dataloaders
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
        if val_dataset:
            val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
        logger.info(f"Training dataset size: {len(train_dataset)}, Validation dataset size: {len(val_dataset) if val_dataset else 0}")

        # Initialize optimizer and loss function
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)
        pad_idx = token_to_index["<PAD>"]
        criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        logger.info(f"Using padding index: {pad_idx}")

        # Training loop
        best_val_loss = float('inf')
        epochs_no_improve = 0
        best_metrics = {}

        for epoch in range(num_epochs):
            # Training phase
            model.train()
            total_train_loss = 0
            for batch_idx, (images, labels) in enumerate(train_dataloader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images, labels[:, :-1], pad_idx)  # Exclude <EOS> from input
                loss = criterion(outputs.reshape(-1, model.vocab_size), labels[:, 1:].reshape(-1))  # Exclude <SOS> from target
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
                optimizer.step()
                total_train_loss += loss.item()
                if batch_idx % 10 == 0:
                    logger.info(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_dataloader)}, Loss: {loss.item():.4f}")

            avg_train_loss = total_train_loss / len(train_dataloader)
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Average Training Loss: {avg_train_loss:.4f}")

            # Validation phase
            if val_dataset:
                model.eval()
                total_val_loss = 0
                with torch.no_grad():
                    for images, labels in val_dataloader:
                        images, labels = images.to(device), labels.to(device)
                        outputs = model(images, labels[:, :-1], pad_idx)
                        loss = criterion(outputs.reshape(-1, model.vocab_size), labels[:, 1:].reshape(-1))
                        total_val_loss += loss.item()

                avg_val_loss = total_val_loss / len(val_dataloader)
                logger.info(f"Epoch {epoch+1}/{num_epochs}, Average Validation Loss: {avg_val_loss:.4f}")

                # Learning rate scheduling
                scheduler.step(avg_val_loss)

                # Early stopping
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    epochs_no_improve = 0
                    best_metrics = {"best_val_loss": best_val_loss, "epoch": epoch + 1}
                    try:
                        model.save(model_save_path)
                    except Exception as e:
                        logger.error(f"Error saving model: {e}")
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        logger.info(f"Early stopping triggered after {epoch+1} epochs")
                        break

        logger.info(f"Best validation loss: {best_metrics.get('best_val_loss', float('inf')):.4f} at epoch {best_metrics.get('epoch', 0)}")
        return best_metrics

    except Exception as e:
        logger.error(f"Error during training: {e}")
        raise

def split_dataset(dataset: MathEquationDataset, val_split: float = 0.1) -> Tuple[MathEquationDataset, MathEquationDataset]:
    """
    Split the dataset into training and validation sets.

    Args:
        dataset (MathEquationDataset): Full dataset.
        val_split (float): Fraction of the dataset to use for validation.

    Returns:
        Tuple[MathEquationDataset, MathEquationDataset]: Training and validation datasets.
    """
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(val_split * dataset_size)
    train_indices, val_indices = indices[split:], indices[:split]
    
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    return train_dataset, val_dataset

if __name__ == "__main__":
    # Define paths
    image_dir = "dataset/crohme2023/IMG/train/CROHME2019_preprocessed"
    tokenized_label_dir = "dataset/crohme2023/labels/tokenized_labels_2019"
    vocab_file = "vocab.json"
    model_save_path = "math_equation_model.pth"

    # Load vocabulary
    # token_to_index, index_to_token = load_vocabulary(vocab_file)
    vocab_size = len(token_to_index)

    # Initialize dataset
    try:
        full_dataset = MathEquationDataset(
            image_dir=image_dir,
            tokenized_label_dir=tokenized_label_dir,
            vocab_file=vocab_file,
            max_seq_len=198,  # Use the actual max sequence length
            image_size=(256, 256)
        )
    except Exception as e:
        logger.error(f"Error initializing dataset: {e}")
        raise

    # Split dataset into training and validation
    train_dataset, val_dataset = split_dataset(full_dataset, val_split=0.1)

    # Initialize model
    try:
        model = MathEquationModel(
            vocab_size=vocab_size,
            image_size=(256, 256),
            embed_dim=256,
            num_heads=4,
            num_layers=2,
            max_seq_len=198,
            dropout=0.1
        )
    except Exception as e:
        logger.error(f"Error initializing model: {e}")
        raise

    # Train the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_model(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        num_epochs=10,
        batch_size=16,
        device=device,
        learning_rate=0.0001,
        model_save_path=model_save_path,
        patience=5
    )

    logger.info("Training complete!")

2025-03-08 22:03:16,758 - INFO - Training on device: cpu
2025-03-08 22:03:16,759 - INFO - Training dataset size: 9882, Validation dataset size: 1097
2025-03-08 22:03:16,760 - INFO - Using padding index: 0
2025-03-08 22:03:18,340 - INFO - Epoch 1/10, Batch 0/618, Loss: 4.9471
2025-03-08 22:03:28,124 - INFO - Epoch 1/10, Batch 10/618, Loss: 3.7911
2025-03-08 22:03:36,331 - INFO - Epoch 1/10, Batch 20/618, Loss: 3.7948
2025-03-08 22:03:45,337 - INFO - Epoch 1/10, Batch 30/618, Loss: 3.6030
2025-03-08 22:03:53,625 - INFO - Epoch 1/10, Batch 40/618, Loss: 3.6406
2025-03-08 22:04:01,108 - INFO - Epoch 1/10, Batch 50/618, Loss: 3.4692
2025-03-08 22:04:09,489 - INFO - Epoch 1/10, Batch 60/618, Loss: 3.8480
2025-03-08 22:04:18,220 - INFO - Epoch 1/10, Batch 70/618, Loss: 3.5745
2025-03-08 22:04:26,121 - INFO - Epoch 1/10, Batch 80/618, Loss: 3.6827
2025-03-08 22:04:34,546 - INFO - Epoch 1/10, Batch 90/618, Loss: 3.3641
2025-03-08 22:04:43,225 - INFO - Epoch 1/10, Batch 100/618, Loss: 3.4116
202

In [98]:
import sympy as sp

def parse_tokens(tokens: list) -> tuple:
    """
    Parse a list of tokens into a SymPy equation.

    Args:
        tokens (list): List of tokens from the model (e.g., ['x', '+', '2', '=', '5']).

    Returns:
        tuple: (left_expr, right_expr) as SymPy expressions.
    """
    # Define symbols
    variables = {t: sp.Symbol(t) for t in tokens if t.isalpha() and t not in ['sin', 'cos', 'tan', 'log']}

    # Split at '=' to get left and right sides
    if '=' not in tokens:
        raise ValueError("No '=' found in equation")
    eq_idx = tokens.index('=')
    left_tokens = tokens[:eq_idx]
    right_tokens = tokens[eq_idx + 1:]

    # Convert token lists to SymPy expressions
    def tokens_to_expr(token_list):
        expr_str = ' '.join(token_list).replace('^', '**')  # Handle exponents
        return sp.sympify(expr_str, locals=variables)

    left_expr = tokens_to_expr(left_tokens)
    right_expr = tokens_to_expr(right_tokens)
    return left_expr, right_expr

# # Example usage
# if __name__ == "__main__":
#     model = MathEquationModel(vocab_size=len(token_to_index), max_seq_len=198)
#     model.load("math_equation_model.pth", device="cpu")
    
#     # Dummy prediction for testing
#     dummy_image = torch.randn(1, 1, 256, 256)
#     tokens = model.predict(dummy_image, token_to_index, index_to_token, device="cpu")
#     print(f"Predicted tokens: {tokens}")
    
#     left_expr, right_expr = parse_tokens(tokens)
#     print(f"Parsed equation: {left_expr} = {right_expr}")

In [99]:
def solve_equation(left_expr, right_expr, equation_type="algebra") -> dict:
    """
    Solve the equation and return step-by-step solutions.

    Args:
        left_expr: SymPy expression for the left side.
        right_expr: SymPy expression for the right side.
        equation_type (str): Type of equation ('algebra', 'calculus_deriv', 'calculus_int', 'linear_system', 'diff_eq').

    Returns:
        dict: Solutions and steps.
    """
    x = sp.Symbol('x')
    steps = []

    if equation_type == "algebra":
        eq = sp.Eq(left_expr, right_expr)
        solutions = sp.solve(eq, x)
        steps.append(f"Start with: {eq}")
        steps.append(f"Solve for x: {solutions}")
    elif equation_type == "calculus_deriv":
        deriv = sp.diff(left_expr, x)
        steps.append(f"Differentiate {left_expr} with respect to x: {deriv}")
        solutions = deriv
    elif equation_type == "calculus_int":
        integral = sp.integrate(left_expr, x)
        steps.append(f"Integrate {left_expr} with respect to x: {integral} + C")
        solutions = integral
    elif equation_type == "linear_system":
        # Expect left_expr and right_expr to be lists for systems
        eqs = [sp.Eq(l, r) for l, r in zip(left_expr, right_expr)]
        solutions = sp.solve(eqs)
        steps.append(f"Solve system: {eqs}")
        steps.append(f"Solution: {solutions}")
    elif equation_type == "diff_eq":
        y = sp.Function('y')(x)
        eq = sp.Eq(left_expr, right_expr)
        solutions = sp.dsolve(eq, y)
        steps.append(f"Differential equation: {eq}")
        steps.append(f"Solution: {solutions}")

    return {"solutions": solutions, "steps": steps}

# Example usage
if __name__ == "__main__":
    left_expr, right_expr = parse_tokens(['x', '+', '2', '=', '5'])
    result = solve_equation(left_expr, right_expr, "algebra")
    print(f"Solutions: {result['solutions']}")
    print("Steps:")
    for step in result["steps"]:
        print(step)

Solutions: [3]
Steps:
Start with: Eq(x + 2, 5)
Solve for x: [3]


In [100]:
def detailed_solve_algebra(left_expr, right_expr) -> list:
    """Generate detailed steps for algebraic equations."""
    x = sp.Symbol('x')
    eq = sp.Eq(left_expr, right_expr)
    steps = [f"Given: {eq}"]

    # Isolate x step-by-step
    if left_expr.has(x):
        steps.append(f"Subtract {left_expr - x} from both sides:")
        new_eq = sp.Eq(x, right_expr - (left_expr - x))
        steps.append(f"Result: {new_eq}")
        solutions = sp.solve(new_eq, x)
        steps.append(f"Simplify: x = {solutions[0]}")
    return steps

# Example usage
if __name__ == "__main__":
    left_expr, right_expr = parse_tokens(['x', '+', '2', '=', '5'])
    steps = detailed_solve_algebra(left_expr, right_expr)
    for step in steps:
        print(step)

Given: Eq(x + 2, 5)
Subtract 2 from both sides:
Result: Eq(x, 3)
Simplify: x = 3


In [101]:
def explain_steps(steps: list) -> list:
    """Convert symbolic steps into plain English."""
    explanations = []
    for step in steps:
        if "Given" in step:
            explanations.append(f"We start with the equation {step.split(': ')[1]}.")
        elif "Subtract" in step:
            term = step.split("Subtract ")[1].split(" from")[0]
            explanations.append(f"Next, we subtract {term} from both sides to isolate the variable.")
        elif "Result" in step:
            explanations.append(f"This gives us: {step.split(': ')[1]}.")
        elif "Simplify" in step:
            explanations.append(f"Finally, we simplify to find that {step.split(': ')[1]}.")
    return explanations

# Example usage
if __name__ == "__main__":
    steps = detailed_solve_algebra(left_expr, right_expr)
    explanations = explain_steps(steps)
    for exp in explanations:
        print(exp)

We start with the equation Eq(x + 2, 5).
Next, we subtract 2 from both sides to isolate the variable.
This gives us: Eq(x, 3).
Finally, we simplify to find that x = 3.


In [107]:
import torch
import cv2

def process_image(image_path: str, model, token_to_index, index_to_token) -> list:
    # image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    image = preprocess_image(image_path, target_size=(256, 256))
    image = torch.tensor(image, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
    tokens = model.predict(image, token_to_index, index_to_token, device="cpu")
    return tokens

def solve_and_explain(image_path: str, model, token_to_index, index_to_token):
    tokens = process_image(image_path, model, token_to_index, index_to_token)
    print(f"Recognized equation: {' '.join(tokens)}")
    
    left_expr, right_expr = parse_tokens(tokens)
    steps = detailed_solve_algebra(left_expr, right_expr)
    explanations = explain_steps(steps)
    
    print("\nStep-by-Step Solution:")
    for exp in explanations:
        print(exp)

if __name__ == "__main__":
    # Load model and vocabulary
    # token_to_index, index_to_token = load_vocabulary("vocab.json")
    model = MathEquationModel(vocab_size=len(token_to_index), max_seq_len=198)
    model.load("math_equation_model.pth", device="cpu")
    
    # Test with a photo
    image_path = "/Users/hardikdudeja/Desktop/img.png" 
    solve_and_explain(image_path, model, token_to_index, index_to_token)

2025-03-08 23:51:17,799 - INFO - Model loaded from math_equation_model.pth


Recognized equation: { { } } } }


ValueError: No '=' found in equation