# Project detection math expression

## Imports

In [40]:
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, utils
from collections import Counter
from PIL import Image
#from tqdm import tqdm

import os
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models.detection as detection

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()  

<contextlib.ExitStack at 0x7ef417033390>

## Pre treatement

In [None]:
class CROHMEDataset(Dataset):
    """
    Dataset pour les expressions complètes (PNG + LG).
    Chaque sample retourne :
        - image : Tensor CxHxW
        - target : dict contenant "boxes" et "labels"
    """

    def __init__(self, root, transform=None, meta_classes=True):
        """
        root : chemin du dossier contenant les PNG + LG
        transform : transform PyTorch (augmentations, ToTensor, Resize…)
        meta_classes : si True, map chaque label vers une méta-classe
        """
        self.root = root
        self.transform = transform
        self.meta_classes = meta_classes

        # liste des fichiers PNG / LG
        self.images = [f for f in os.listdir(root) if f.endswith(".png")]
        self.images.sort()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.root, img_name)

        lg_name = img_name.replace(".png", ".lg")
        lg_path = os.path.join(self.root, lg_name)

        image = Image.open(img_path).convert("RGB")

        boxes = []
        labels = []

        with open(lg_path, "r", encoding='utf-8', errors='ignore') as f:
            for line in f:
                parts = [p.strip() for p in line.strip().split(",") if p.strip() != ""]
                if len(parts) < 6:
                    # fallback to whitespace splitting if commas are not reliable
                    parts = [p.strip() for p in line.strip().split() if p.strip() != ""]

                if len(parts) < 6:
                    continue

                label = parts[1]
                try:
                    xmin_s, ymin_s, xmax_s, ymax_s = parts[-4:]
                    xmin = float(xmin_s)
                    ymin = float(ymin_s)
                    xmax = float(xmax_s)
                    ymax = float(ymax_s)
                except Exception:
                    continue

                if xmax <= xmin or ymax <= ymin:
                    warnings.warn(
                            f"Found invalid bbox in '{lg_path}': [xmin={xmin}, ymin={ymin}, xmax={xmax}, ymax={ymax}]. These boxes will be skipped.")
                    continue

                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(self.map_label(label))

        # Convert to tensors; ensure correct shapes even when empty
        if len(boxes) == 0:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.int64)

        target = {"boxes": boxes, "labels": labels}

        if self.transform:
            image = self.transform(image)

        return image, target

    def map_label(self, label):
        raw = label.split("_")[0].strip()
        if raw.isalpha():
            return 0

        if raw.isdigit():
            return 1

        if raw in {"+", "-", "=", "/", "*", "×", "÷", "^"}:
            return 2
        return 3

    def raw_label_to_id(self, raw):
        if not hasattr(self, "raw_vocab"):
            self.raw_vocab = {}
        if raw not in self.raw_vocab:
            self.raw_vocab[raw] = len(self.raw_vocab)
        return self.raw_vocab[raw]

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

In [None]:
root = "datas/FullExpressions/CROHME2019_train_png/"      #clara
#root = "../../data/FullExpressions/CROHME2019_train_png/" #marie

In [42]:
dataset = CROHMEDataset(
    root=root,
    transform=transforms.ToTensor(),
    meta_classes=True
)

image, target = dataset[0]
print("Image : ", image.size())
print("Target : ", target)

dataset_len = len(dataset)
train_len = int(0.8 * dataset_len)
val_len = int(0.1 * dataset_len)
test_len = dataset_len - train_len - val_len

train, val, test = torch.utils.data.random_split(
    dataset, [train_len, val_len, test_len], generator=torch.Generator().manual_seed(42)
)
print(f"Dataset sizes -> total: {dataset_len}, train: {train_len}, val: {val_len}, test: {test_len}")

Image :  torch.Size([3, 119, 500])
Target :  {'boxes': tensor([[ 10.,  40.,  39., 108.],
        [270.,  46., 320.,  72.],
        [340.,  47., 371.,  71.],
        [467.,  10., 489.,  38.],
        [121.,  38., 166.,  63.],
        [226.,  21., 265.,  65.],
        [399.,  10., 452.,  76.]]), 'labels': tensor([0, 0, 2, 1, 2, 0, 0])}
Dataset sizes -> total: 9993, train: 7994, val: 999, test: 1000


## Functions for visualization and evaluation

In [43]:
def load_image(image_path):
    """Load an image from file."""
    image = Image.open(image_path).convert("RGB")
    return image

def prepare_image(image, transform=None):
    """Prepare the image for model input."""
    if transform:
        image = transform(image)
    return image.unsqueeze(0)  # Add batch dimension

def visualize_predictions(image, boxes, labels, scores, threshold=0.4):
    """Visualize the bounding boxes and labels on the image."""
    plt.figure(figsize=(12, 8))
    plt.imshow(image.permute(1, 2, 0).numpy())

    # Filter out boxes and labels below the threshold
    for box, label, score in zip(boxes, labels, scores):
        if score >= threshold:
            x_min, y_min, x_max, y_max = box
            plt.gca().add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                                fill=False, edgecolor='red', linewidth=3))
            plt.text(x_min, y_min, f'{label.item()}: {score:.2f}', fontsize=12, color='red')

    plt.axis('off')
    plt.show()

In [None]:
# IoU et mAP qui viennent d'Object_Segmentation

def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Calculates intersection over union

    Parameters:
        boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
        boxes_labels (tensor): Correct Labels of Boxes (BATCH_SIZE, 4)
        box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)

    Returns:
        tensor: Intersection over union for all examples
    """

    if box_format == "midpoint":
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2

    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # Need clamp(0) in case they do not intersect, then we want intersection to be 0
    intersection = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    box1_area = (box1_x2 - box1_x1).abs() * (box1_y2 - box1_y1).abs()
    box2_area = (box2_x2 - box2_x1).abs() * (box2_y2 - box2_y1).abs()

    return intersection / (box1_area + box2_area - intersection + 1e-6)


def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="corners", num_classes=20
):
    """
    Calculates mean average precision

    Parameters:
        pred_boxes (list): list of lists containing all bboxes with each bboxes
        specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
        true_boxes (list): Similar as pred_boxes except all the correct ones
        iou_threshold (float): threshold where predicted bboxes is correct
        box_format (str): "midpoint" or "corners" used to specify bboxes
        num_classes (int): number of classes

    Returns:
        float: mAP value across all classes given a specific IoU threshold
    """

    average_precisions = []
    epsilon = 1e-6

    for c in range(num_classes):
        detections = []
        ground_truths = []

        # Collect detections and ground truths for class c
        for detection in pred_boxes:
            # detection[1] may be int or tensor-like
            if int(detection[1]) == c:
                detections.append(detection)

        for true_box in true_boxes:
            if int(true_box[1]) == c:
                ground_truths.append(true_box)

        amount_bboxes = Counter([gt[0] for gt in ground_truths])

        for key, val in amount_bboxes.items():
            amount_bboxes[key] = torch.zeros(val)

        # sort by confidence score (index 2)
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros((len(detections)))
        FP = torch.zeros((len(detections)))
        total_true_bboxes = len(ground_truths)

        if total_true_bboxes == 0:
            continue

        for detection_idx, detection in enumerate(detections):
            # Ground truths for the same image index
            ground_truth_img = [bbox for bbox in ground_truths if bbox[0] == detection[0]]

            best_iou = 0.0
            best_gt_idx = -1

            for idx, gt in enumerate(ground_truth_img):
                # Compute IoU and convert to float for safe comparisons
                iou_tensor = intersection_over_union(
                    torch.tensor(detection[3:], dtype=torch.float32),
                    torch.tensor(gt[3:], dtype=torch.float32),
                    box_format=box_format,
                )
                iou = float(iou_tensor.item()) if torch.is_tensor(iou_tensor) else float(iou_tensor)

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if best_iou > iou_threshold:
                # only detect ground truth once
                if amount_bboxes[detection[0]][best_gt_idx] == 0:
                    TP[detection_idx] = 1
                    amount_bboxes[detection[0]][best_gt_idx] = 1
                else:
                    FP[detection_idx] = 1
            else:
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)

        # Ensure tensors are float and prepend start points for integration
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))

        # Numerical integration (area under precision-recall curve)
        ap = float(torch.trapz(precisions, recalls).item())
        average_precisions.append(ap)

    if len(average_precisions) == 0:
        return 0.0

    return sum(average_precisions) / len(average_precisions)


## Training loop

In [None]:
# Function for validation
def validate(model, val_loader):
    val_loss = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            val_loss += losses.item()

    return val_loss / len(val_loader)

In [15]:
# Hyperparameters
num_epochs = 15
learning_rate =0.0008
batch_size = 3
# Keep val_size if you want an absolute val count fallback, but we'll use dynamic splits
val_size = 10

val_err_array = np.array([])
train_err_array = np.array([])
nb_sample_array = np.array([])
train_loss_classifier_array = np.array([])
train_loss_objectness_array = np.array([])

# Early stopping parameters
patience =5
epochs_without_improvement = 0

# Use the Subset objects created earlier by random_split: `train`, `val`, `test`.
# If `train` or `val` don't exist yet (cell not executed), compute splits here as a fallback.
try:
    train_subset = train
    val_subset = val
except NameError:
    dataset_len = len(dataset)
    train_len = int(0.8 * dataset_len)
    val_len = int(0.1 * dataset_len)
    test_len = dataset_len - train_len - val_len
    train_subset, val_subset, _ = torch.utils.data.random_split(dataset, [train_len, val_len, test_len], generator=torch.Generator().manual_seed(42))

# Create DataLoaders for training and validation
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

# Load a pretrained Faster R-CNN model
#model = detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
#model = detection.ssd300_vgg16(weights="DEFAULT")
model = detection.fasterrcnn_mobilenet_v3_large_fpn(weights="DEFAULT")

# Set the requires_grad attribute of all the backbone parameters to False
for param in model.backbone.parameters():
    param.requires_grad = False
print("Backbone frozen. Only the RPN and heads will be trained.")

# Modify the model for the number of classes
num_classes = 5  # 20 classes + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

# Move the model to GPU if available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Set up the optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=learning_rate, momentum=0.9, weight_decay=0.0005)


# Training loop
best_val_loss = float('inf')  # Initialize best validation loss
for epoch in range(num_epochs):
    epoch_loss = 0.0
    epoch_loss_classifier = 0.0
    epoch_loss_objectness = 0.0
    model.train()  # Set the model to training mode
    nb_used_sample = 0 # Initialize the number of samples used in this epoch

    for images, targets in train_loader:
        # Move images and targets to the device (GPU or CPU)
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        loss_dict = model(images, targets)

        # Compute total loss
        losses = sum(loss for loss in loss_dict.values())

        # Backward pass
        losses.backward()

        # Update the weights
        optimizer.step()

        # Accumulate loss
        epoch_loss += losses.item()
        # Use .get to avoid KeyError if a particular loss term is missing
        epoch_loss_classifier += loss_dict.get('loss_classifier', torch.tensor(0.0)).item()
        epoch_loss_objectness += loss_dict.get('loss_objectness', torch.tensor(0.0)).item()
        nb_used_sample += len(images)

    # Calculate average training loss for the epoch
    train_err = epoch_loss / len(train_loader)
    train_loss_classifier = epoch_loss_classifier / len(train_loader) if len(train_loader) > 0 else 0.0
    train_loss_objectness = epoch_loss_objectness / len(train_loader) if len(train_loader) > 0 else 0.0

    # Print epoch loss
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {train_err:.4f}, Classifier Loss: {train_loss_classifier:.4f}, Objectness Loss: {train_loss_objectness:.4f}")

    # Validate after each epoch
    val_loss = validate(model, val_loader)
    print(f"Validation Loss: {val_loss:.4f}")
    train_err_array = np.append(train_err_array, train_err)
    val_err_array = np.append(val_err_array, val_loss)
    nb_sample_array = np.append(nb_sample_array, nb_used_sample)
    train_loss_classifier_array = np.append(train_loss_classifier_array, train_loss_classifier)
    train_loss_objectness_array = np.append(train_loss_objectness_array, train_loss_objectness)

    # Save the model weights if validation loss has improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'faster_rcnn_voc_best.pth')
        print(f"Model weights saved. New best validation loss: {best_val_loss:.4f}")
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        print(f"Early stopping after {patience} epochs without improvement.")
        break

# Final message
print("Training complete.")

Backbone frozen. Only the RPN and heads will be trained.


KeyboardInterrupt: 