In [1]:
import os
from PIL import Image
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report
from random import randint
import cv2

In [None]:
main_folder_path = 'Data\train' #path to train dataset
image_list = []
labels_list = []
for root, dirs, files in os.walk(main_folder_path):
    for filename in files:
        if filename.endswith('.jpg') or filename.endswith('.jpeg'):
            try:
                img_path = os.path.join(root, filename)
                label = os.path.basename(root)  
                labels_list.append(label)
                img = Image.open(img_path)
                img_array = np.array(img)
                image_list.append(img_array)
            except Exception as e:
                print(f"Error loading {filename}: {e}")

def process_images(image_list, label_list):
    processed_images = []
    processed_labels = []
    for img, label in zip(image_list, label_list):
        img_pil = Image.fromarray(img).convert('L')
        img_resized = img_pil.resize((128, 128))
        processed_images.append(np.array(img_resized)) 
        processed_labels.append(label)
    stacked_images = np.stack(processed_images, axis=0)
    return stacked_images, processed_labels

images_list, labels = process_images(image_list, labels_list)

In [3]:
def get_random_image(symbol,image_list=images_list,labels_list=labels_list):
    symbol_map = {
    '.': 'decimal',
    '/': 'div',
    '8': 'eight',
    '=': 'equal',
    '5': 'five',
    '4': 'four',
    '-': 'minus',
    '9': 'nine',
    '1': 'one',
    '+': 'plus cleaned',
    '7': 'seven',
    '6': 'six',
    '3': 'three',
    '*': 'times',
    '2': 'two',
    '0': 'zero'
    }
    mask=np.all(np.array([labels_list])==symbol_map[symbol],axis=0)
    return image_list[mask][randint(0,sum(mask)-1)]

def extract_mser_segments(image, resize_shape=(128, 128), padding_factor=2.2, extra_padding=50,show_plots:bool=False): # 2.2 50
    mser = cv2.MSER_create()
    regions, _ = mser.detectRegions(image)
    image_color = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    bounding_boxes = []
    
    for p in regions:
        x, y, w, h = cv2.boundingRect(p.reshape(-1, 1, 2))
        if w > 5 and h > 5:
            bounding_boxes.append((x, y, w, h))

    def merge_bounding_boxes(boxes, overlap_thresh=0.3, vertical_merge_thresh=20):
        merged_boxes = []
        used = [False] * len(boxes)

        for i in range(len(boxes)):
            if used[i]:
                continue

            x1, y1, w1, h1 = boxes[i]
            merged_box = [x1, y1, x1 + w1, y1 + h1]

            for j in range(i + 1, len(boxes)):
                if used[j]:
                    continue

                x2, y2, w2, h2 = boxes[j]

                if (x1 < x2 + w2 and x1 + w1 > x2 and y1 < y2 + h2 and y1 + h1 > y2) or \
                   (abs(x1 - x2) < 10 and abs(w1 - w2) < 10 and abs(y1 + h1 - y2) < vertical_merge_thresh):
                    merged_box[0] = min(merged_box[0], x2)
                    merged_box[1] = min(merged_box[1], y2)
                    merged_box[2] = max(merged_box[2], x2 + w2)
                    merged_box[3] = max(merged_box[3], y2 + h2)
                    used[j] = True

            merged_boxes.append((merged_box[0], merged_box[1], merged_box[2] - merged_box[0], merged_box[3] - merged_box[1]))
            used[i] = True

        return merged_boxes

    merged_boxes = merge_bounding_boxes(bounding_boxes)
    merged_boxes = sorted(merged_boxes, key=lambda b: b[0])
    resized_segments = []

    for index, (x, y, w, h) in enumerate(merged_boxes, start=1):
        cv2.rectangle(image_color, (x, y), (x + w, y + h), (0, 255, 0), 2)
        segment = image[y:y+h, x:x+w]
        height, width = segment.shape
        
        max_dim = int(max(height, width) * padding_factor)
        padded_segment = np.full((max_dim, max_dim), 255, dtype=np.uint8)
        
        if height > width:
            new_width = int((max_dim / height) * width)
            resized_segment = cv2.resize(segment, (new_width, max_dim))
            x_offset = (max_dim - new_width) // 2
            padded_segment[:, x_offset:x_offset + new_width] = resized_segment
        else:
            new_height = int((max_dim / width) * height)
            resized_segment = cv2.resize(segment, (max_dim, new_height))
            y_offset = (max_dim - new_height) // 2
            padded_segment[y_offset:y_offset + new_height, :] = resized_segment    
        
        extra_padded_segment = np.full((max_dim + 2 * extra_padding, max_dim + 2 * extra_padding), 255, dtype=np.uint8)
        extra_padded_segment[extra_padding:max_dim + extra_padding, extra_padding:max_dim + extra_padding] = padded_segment
    
        final_segment = cv2.resize(extra_padded_segment, resize_shape)
        resized_segments.append(final_segment)
    return resized_segments

def threshold_image(image: np.ndarray):
    return (image > 240).astype('uint8') * 255

def cutspace(image: np.ndarray, left: bool = False, right: bool = False):
    if not left and not right:
        return image
    
    rows, cols = image.shape
    left_cutoff = 0
    right_cutoff = cols

    if left:
        for i in range(cols):
            if not np.all(image[:, i] == 255):
                left_cutoff = i//2
                break

    if right:
        for i in range(cols - 1, -1, -1):
            if not np.all(image[:, i] == 255):
                right_cutoff = (i + 1+cols)//2
                break

    return image[:, left_cutoff:right_cutoff]

def get_eqn(eqn: str):
    eqn_images = []
    for i, symbol in enumerate(eqn):
        image = threshold_image(get_random_image(symbol))
        if symbol.isdigit():
            left = i != 0 and eqn[i - 1].isdigit()
            right = i != len(eqn) - 1 and eqn[i + 1].isdigit()
            image = cutspace(image, left, right)
        eqn_images.append(image)
    
    output = np.hstack(eqn_images)
    return output

import re

def parse_equation(equation_str):
    equation_str = equation_str.replace(" ", "")
    pattern = r"(\d+\.?\d*)([\+\-\*/])(\d+\.?\d*)"
    match = re.match(pattern, equation_str)
    if not match:
        raise ValueError("Invalid equation format. Expected format like '12.5 + 3.4 ='")
    num1 = float(match.group(1))  
    operator = match.group(2)
    num2 = float(match.group(3)) 

    if operator == "+":
        result = num1 + num2
    elif operator == "-":
        result = num1 - num2
    elif operator == "*":
        result = num1 * num2
    elif operator == "/":
        if num2 == 0:
            raise ZeroDivisionError("Division by zero is not allowed.")
        result = num1 / num2
    else:
        raise ValueError("Unsupported operator. Supported operators are: +, -, *, /")

    return int(result) if result.is_integer() else result

def get_number(n:int):
    return randint(10**(n-1),10**(n)-1)
def generate_equation():
    operator_choices=["+","-","*"]
    order_1=randint(1,3)
    order_2=randint(1,3)
    num1=str(get_number(order_1))
    num2=str(get_number(order_2))
    return num1+operator_choices[randint(0,2)]+num2

def generate_final():
    eqn=generate_equation()
    output=parse_equation(eqn)
    image=get_eqn(eqn)
    return image,output

def get_image_and_labels():
    eqn=generate_equation()
    image=get_eqn(eqn)
    segments=extract_mser_segments(image)
    if len(segments)==len(eqn):
        return segments,list(eqn)
    else:
        return None,None

## Augmenting Data
* If the model is directly applied to predict segmented output it will fail as the segmented output does not have any padding
* To prevent this problem, random equations are generated, corresponding images are picked and horizonatally stacked with reduced spacing between digits to simulate handwritting
* Then created a dataset of approximately 10,000 which are images taken after segmenting and padding the randomly generated equations

In [4]:
new_images_list=[]
new_labels_list=[]
while len(new_images_list)<10000:
    img_list,lbl_list=get_image_and_labels()
    if img_list:
        new_images_list.extend(img_list)
        new_labels_list.extend(lbl_list)

In [5]:
images_list=new_images_list
labels=new_labels_list

In [6]:
num_classes = 16
image_array = torch.tensor(images_list, dtype=torch.float32)
image_tensor = image_array.unsqueeze(1) / 255.0  

label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(labels)  
labels_tensor = torch.tensor(labels_encoded, dtype=torch.long)

class CustomDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images) 

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label

  image_array = torch.tensor(images_list, dtype=torch.float32)


## Initializing the model
* added regularization as previous simple CNN models ended up overfitting within 10 epochs with the same hyperparameters

In [8]:
class RegularizedCNN(nn.Module):
    def __init__(self, num_classes=16):
        super(RegularizedCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.5)  
        self.fc1 = nn.Linear(32 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.batch_norm1 = nn.BatchNorm2d(16)  
        self.batch_norm2 = nn.BatchNorm2d(32)

    def forward(self, x):
        x = self.pool(F.relu(self.batch_norm1(self.conv1(x))))
        x = self.pool(F.relu(self.batch_norm2(self.conv2(x))))
        x = x.view(-1, 32 * 32 * 32)  
        x = F.relu(self.fc1(x))
        x = self.dropout(x)  
        x = self.fc2(x)
        return F.log_softmax(x, dim=1) 


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = RegularizedCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [9]:
dataset = CustomDataset(image_tensor, labels_tensor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

## Training loop

In [10]:
num_epochs = 10  
for epoch in range(num_epochs):
    model.train()  
    epoch_loss = 0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)

        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_loss = epoch_loss / len(dataloader)
    accuracy = 100 * correct / total
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

Epoch [1/10], Loss: 1.0784, Accuracy: 66.86%
Epoch [2/10], Loss: 0.4854, Accuracy: 83.31%
Epoch [3/10], Loss: 0.3769, Accuracy: 86.71%
Epoch [4/10], Loss: 0.3015, Accuracy: 89.09%
Epoch [5/10], Loss: 0.2880, Accuracy: 90.32%
Epoch [6/10], Loss: 0.2421, Accuracy: 91.40%
Epoch [7/10], Loss: 0.2141, Accuracy: 92.32%
Epoch [8/10], Loss: 0.1933, Accuracy: 93.11%
Epoch [9/10], Loss: 0.1827, Accuracy: 93.40%
Epoch [10/10], Loss: 0.1728, Accuracy: 93.81%


## Saving model

In [None]:
model_save_path = 'reg_cnn_model.pth'
torch.save(model.state_dict(), model_save_path)
with open("LabelEncoder.pkl", 'wb') as f:
    pkl.dump(label_encoder, f)
print(f"Model saved to {model_save_path}")

## model evaluation

In [11]:
model.eval()  
all_labels = []
all_predictions = []
with torch.no_grad():
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

print("\nClassification Report:")
print(classification_report(all_labels, all_predictions, target_names=label_encoder.classes_))

accuracy = accuracy_score(all_labels, all_predictions)
print(f"Overall Accuracy: {accuracy:.2f}")


Classification Report:
              precision    recall  f1-score   support

           *       0.99      1.00      1.00       680
           +       1.00      1.00      1.00       616
           -       1.00      1.00      1.00       696
           0       1.00      0.98      0.99       390
           1       1.00      0.99      1.00       827
           2       0.99      1.00      0.99       865
           3       1.00      1.00      1.00       848
           4       1.00      1.00      1.00       883
           5       1.00      1.00      1.00       830
           6       0.99      1.00      0.99       836
           7       0.99      1.00      1.00       827
           8       1.00      0.99      1.00       838
           9       1.00      1.00      1.00       867

    accuracy                           1.00     10003
   macro avg       1.00      1.00      1.00     10003
weighted avg       1.00      1.00      1.00     10003

Overall Accuracy: 1.00
