# It3 - Creating the numerical dictionary V2

## Importing libraries

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt


## Functions to load the images and the templates (masks)

In [None]:
def load_images_from_folder(folder_path):
    """
    Load every ``.JPG`` / ``.png`` image found directly inside a folder.

    Files are visited in sorted filename order. ``os.listdir`` returns entries
    in arbitrary order, and the returned lists are consumed by position, so a
    deterministic order is required for reproducible results.

    Args:
    - folder_path: Path of the folder to scan (sub-folders are not visited).

    Returns:
    - images: List of images as returned by ``cv2.imread`` with default flags.
    - image_ids: List of filenames without their extension; ``image_ids[k]``
      belongs to ``images[k]``.
    """
    images = []
    image_ids = []

    for filename in sorted(os.listdir(folder_path)):
        # The suffix match is case-sensitive: only ".JPG" and ".png" are accepted
        if not filename.endswith((".JPG", ".png")):
            continue

        image_path = os.path.join(folder_path, filename)
        image = cv2.imread(image_path)

        # cv2.imread reports an unreadable file by returning None instead of raising
        if image is None:
            print(f"Warning: could not read image '{image_path}', skipping it.")
            continue

        images.append(image)
        # The ID is the filename with its extension removed
        image_ids.append(os.path.splitext(filename)[0])

    return images, image_ids


In [None]:
def load_templates(template_folder_path):
    """
    Load the grayscale mask templates stored in the 'cell', 'row' and 'column' subfolders.

    Files of each subfolder are visited in sorted filename order. ``os.listdir``
    returns entries in arbitrary order, and the returned lists are consumed by
    position, so a deterministic order is required for reproducible results.

    Args:
    - template_folder_path: Folder that contains the three subfolders.

    Returns:
    - templates: Dictionary with the keys 'cell', 'row' and 'column', each mapped
      to the list of templates read with ``cv2.IMREAD_GRAYSCALE``.
    """
    templates = {}

    for subfolder in ['cell', 'row', 'column']:
        subfolder_path = os.path.join(template_folder_path, subfolder)
        template_list = []

        for filename in sorted(os.listdir(subfolder_path)):
            # The suffix match is case-sensitive: only ".jpg" and ".png" are accepted
            if not filename.endswith((".jpg", ".png")):
                continue

            template_path = os.path.join(subfolder_path, filename)
            template = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)

            # cv2.imread reports an unreadable file by returning None instead of raising
            if template is None:
                print(f"Warning: could not read template '{template_path}', skipping it.")
                continue

            template_list.append(template)

        templates[subfolder] = template_list

    return templates

## Creating the Digit Object

In [None]:
class DigitObject:
    """
    Plain container for one extracted digit: the image itself, the label of the
    cell it came from, its stored size, its bounding-box limits and an optional
    prediction.
    """

    def __init__(self, digit_image, cell_number, width, length, x_min, x_max, y_min, y_max, prediction=None):
        # Image payload and the label of the cell it belongs to
        self.digit_image, self.cell_number = digit_image, cell_number
        # Size values exactly as supplied by the caller
        self.width, self.length = width, length
        # Bounding-box limits exactly as supplied by the caller
        self.x_min, self.x_max = x_min, x_max
        self.y_min, self.y_max = y_min, y_max
        # Stays None unless the caller provides a value
        self.prediction = prediction

    def to_dict(self):
        """Return all nine attributes in a new dictionary keyed by attribute name."""
        field_names = (
            'digit_image',
            'cell_number',
            'width',
            'length',
            'x_min',
            'x_max',
            'y_min',
            'y_max',
            'prediction',
        )
        return {name: getattr(self, name) for name in field_names}



## Functions for image cropping and detection (Cells and Digits)

### Cells

In [None]:
def extract_cells_with_improved_sorting(image, column_template, row_template, cell_template, columns_of_interest=('A', 'C', 'D', 'E')):
    """
    Extract cells from the image using the column, row, and cell masks and organize them by column and row.

    Args:
    - image: Source image; converted to grayscale when it has three dimensions.
    - column_template: Single-channel mask whose external contours are taken as the columns.
    - row_template: Single-channel mask whose external contours are taken as the rows.
    - cell_template: Grayscale mask; it is thresholded at 127 and its external contours are taken as the cells.
    - columns_of_interest: Column letters to keep. Column contours are labelled
      'A' to 'E' from left to right; contours beyond the fifth are ignored.

    Returns:
    - cells: Dictionary mapping a column letter to a list of
      (cell_number, cropped_cell) tuples. cell_number is the column letter
      followed by the 1-based row index (e.g. "C4"); cropped_cell is cut from
      the image after binarizing it with a fixed threshold of 127. Cells whose
      centre falls in no kept column or in no row are dropped.
    """
    # NOTE(review): the column and row masks are passed to findContours as-is,
    # while the cell mask is thresholded first - confirm the masks are clean binary images.

    # Step 1: Detect columns, ordered left to right by the x of their bounding box
    column_contours, _ = cv2.findContours(column_template, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    column_rects = sorted((cv2.boundingRect(c) for c in column_contours), key=lambda rect: rect[0])

    # Map column names to bounding boxes; zip stops after the fifth contour
    column_names = ['A', 'B', 'C', 'D', 'E']
    column_bboxes = {
        col_name: rect
        for col_name, rect in zip(column_names, column_rects)
        if col_name in columns_of_interest
    }

    # Step 2: Detect rows, ordered top to bottom by the y of their bounding box
    row_contours, _ = cv2.findContours(row_template, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    row_bboxes = sorted((cv2.boundingRect(c) for c in row_contours), key=lambda rect: rect[1])

    # Step 3: Detect cells using the cell mask
    image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
    _, binary_image = cv2.threshold(image_gray, 127, 255, cv2.THRESH_BINARY)
    _, binary_cell_template = cv2.threshold(cell_template, 127, 255, cv2.THRESH_BINARY)
    cell_contours, _ = cv2.findContours(binary_cell_template, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Order cells by y first and then by x so each column's list runs top to bottom
    cell_rects = sorted((cv2.boundingRect(c) for c in cell_contours), key=lambda rect: (rect[1], rect[0]))

    cells = {}

    # Step 4: Assign cells to columns and rows through the centre of their bounding box
    for x, y, w, h in cell_rects:
        centroid_x = x + w // 2
        centroid_y = y + h // 2

        # First kept column whose horizontal extent contains the centre
        column_assigned = next(
            (
                col_name
                for col_name, (col_x, _, col_w, _) in column_bboxes.items()
                if col_x <= centroid_x <= col_x + col_w
            ),
            None,
        )

        # First row (1-based) whose vertical extent contains the centre
        row_assigned = next(
            (
                row_idx + 1
                for row_idx, (_, row_y, _, row_h) in enumerate(row_bboxes)
                if row_y <= centroid_y <= row_y + row_h
            ),
            None,
        )

        # Keep the cell only when both a column and a row were found
        if column_assigned is None or row_assigned is None:
            continue

        cell_number = f"{column_assigned}{row_assigned}"
        cropped_cell = binary_image[y:y + h, x:x + w]
        cells.setdefault(column_assigned, []).append((cell_number, cropped_cell))

    return cells


### Digits

In [None]:
def extract_digits_from_cells(cells, output_folder):
    """
    Extract digits from each cell, create DigitObject instances, and save them in a dictionary.

    Each kept digit is also written to output_folder as
    "<cell_number>_digit<k>.png", where k is the 1-based position of its contour
    among all contours of the cell (filtered-out contours still consume a
    number). The file name does not identify the source image, so calls that
    share an output folder overwrite files that have the same cell number and index.

    Args:
    - cells: Dictionary containing cells organized by columns; each value is a
      list of (cell_number, cell_image) tuples with NumPy array images.
    - output_folder: Path to the output folder to save digit images (created if missing).

    Returns:
    - digit_objects: Dictionary mapping a column name to the list of DigitObject
      instances found in that column. Columns without any kept digit are absent.
      Bounding-box values are relative to the cell image and include the padding.
    """
    # Ensure the output directory exists
    os.makedirs(output_folder, exist_ok=True)

    digit_objects = {}

    # Structuring element shared by every morphological operation below
    kernel = np.ones((3, 3), np.uint8)

    for col_name, cell_list in cells.items():
        for cell_number, cell in cell_list:
            # A zero-size crop holds no digit and cannot be thresholded
            if cell.size == 0:
                continue

            cell_height, cell_width = cell.shape[:2]

            # Convert cell to grayscale if needed
            cell_gray = cv2.cvtColor(cell, cv2.COLOR_BGR2GRAY) if len(cell.shape) == 3 else cell

            # Adaptive threshold with inversion: dark strokes become white foreground
            binary_cell = cv2.adaptiveThreshold(
                cell_gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 15, 10
            )

            # Close small gaps, then dilate; dilation thickens the strokes (it does not separate digits)
            binary_cell = cv2.morphologyEx(binary_cell, cv2.MORPH_CLOSE, kernel)
            binary_cell = cv2.dilate(binary_cell, kernel, iterations=1)

            # Candidate digits are the external contours, ordered left to right
            digit_contours, _ = cv2.findContours(binary_cell, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            digit_rects = sorted(map(cv2.boundingRect, digit_contours), key=lambda rect: rect[0])

            for digit_idx, (x, y, w, h) in enumerate(digit_rects):
                # Reject regions that are too small or that span the whole cell
                if not (0.25 * cell_height < h < cell_height and 0.05 * cell_width < w < cell_width):
                    continue

                # Expand the box by 10% per side, clamped to the cell, to avoid cutting off parts of digits
                padding_x = int(0.1 * w)
                padding_y = int(0.1 * h)
                x_min = max(0, x - padding_x)
                y_min = max(0, y - padding_y)
                x_max = min(cell_width, x + w + padding_x)
                y_max = min(cell_height, y + h + padding_y)

                # Crop from the processed binary image and resize to 28x28 for consistency
                digit = binary_cell[y_min:y_max, x_min:x_max]
                digit_resized = cv2.resize(digit, (28, 28), interpolation=cv2.INTER_AREA)

                digit_obj = DigitObject(
                    digit_image=digit_resized,
                    cell_number=cell_number,
                    width=28,
                    length=28,
                    x_min=x_min,
                    x_max=x_max,
                    y_min=y_min,
                    y_max=y_max,
                    prediction=None  # Placeholder for prediction, can be updated after model inference
                )
                digit_objects.setdefault(col_name, []).append(digit_obj)

                # Save the digit image; cv2.imwrite returns False when the file was not written
                digit_filename = f"{cell_number}_digit{digit_idx + 1}.png"
                digit_path = os.path.join(output_folder, digit_filename)
                if not cv2.imwrite(digit_path, digit_resized):
                    print(f"Warning: could not write digit image '{digit_path}'.")

    return digit_objects


## Running everything

### Loading the images

In [None]:
# Folder holding the input images (relative path, resolved against the current working directory)
folder_path = "../data/sub60Cropped"
# images[k] and image_ids[k] refer to the same file
images, image_ids = load_images_from_folder(folder_path)

# Report how many images were loaded
print(f"Loaded {len(images)} images")
# Uncomment to also list the ID of every loaded image:
# for img_id in image_ids:
#     print(f"Loaded Image ID: {img_id}")
    

### Loading the templates (masks)

In [None]:
# Root folder of the masks; load_templates reads its 'cell', 'row' and 'column' subfolders
template_folder_path = '../data/Cropped Masks'
# templates maps each subfolder name to its list of grayscale masks
templates = load_templates(template_folder_path)

### Extracting the digits using the masks

In [None]:
# Folder into which extract_digits_from_cells writes the digit crops
output_folder = './Data/It3/digits'
# Filled as: image ID -> {column name -> list of DigitObject}
all_digit_objects = {}

# Process at most the first 60 loaded images, each paired with its ID
for i, (image, image_id) in enumerate(zip(images[:60], image_ids)):
    # Pick the i-th mask of each kind; when a list holds fewer masks than
    # there are images, its last mask is reused
    column_template = templates['column'][min(i, len(templates['column']) - 1)]
    row_template = templates['row'][min(i, len(templates['row']) - 1)]
    cell_template = templates['cell'][min(i, len(templates['cell']) - 1)]

    # For debugging, the image and the three masks can be displayed here
    # with plt.imshow(..., cmap='gray') followed by plt.show()

    # Cut the image into cells grouped by column
    cells = extract_cells_with_improved_sorting(image, column_template, row_template, cell_template)

    # Turn the cells into digit objects (this also writes the digit images to disk)
    digit_objects = extract_digits_from_cells(cells, output_folder)

    # Keep the result under the image ID
    all_digit_objects[image_id] = digit_objects


### Printing the dictionary for checkups

In [None]:
# for image_id, columns in all_digit_objects.items():
#     print(f"Image ID: {image_id}")
#     for column_name, digit_list in columns.items():
#         print(f"  Column: {column_name}")
#         for digit_obj in digit_list:
#             # Print each attribute of the DigitObject instance
#             print(f"    Cell Number: {digit_obj.cell_number}")
#             print(f"    Width: {digit_obj.width}")
#             print(f"    Length: {digit_obj.length}")
#             print(f"    X Min: {digit_obj.x_min}")
#             print(f"    X Max: {digit_obj.x_max}")
#             print(f"    Y Min: {digit_obj.y_min}")
#             print(f"    Y Max: {digit_obj.y_max}")
#             print(f"    Prediction: {digit_obj.prediction}")
# 
#             # If you want to display the digit image
#             plt.figure(figsize=(2, 2))
#             plt.imshow(digit_obj.digit_image, cmap='gray')  # The cropped digit is stored in `digit_obj.digit_image`
#             plt.title(f"Digit in {digit_obj.cell_number}")
#             plt.show()


In [None]:
# Report how many images ended up in the result dictionary
num_images_processed = len(all_digit_objects)
print(f"Number of images processed: {num_images_processed}")

# The run is expected to cover exactly 60 images
if num_images_processed != 60:
    print(f"Warning: Only {num_images_processed} images were processed out of 60.")
else:
    print("All 60 images were processed successfully.")

# Print the digit count of every image and accumulate the grand total in the same pass
total_digits_all_images = 0
for image_id, columns in all_digit_objects.items():
    total_digits = sum(map(len, columns.values()))
    print(f"Image ID: {image_id}, Total Digits Extracted: {total_digits}")
    total_digits_all_images += total_digits

# Grand total over all images and columns
print(f"Total number of digit objects extracted from all images: {total_digits_all_images}")


### Saving the dictionary

In [None]:
import os
import json
import cv2

# Folder receiving one PNG per extracted digit; creating it also creates
# 'Data/It3', where the JSON file below is written
output_folder = 'Data/It3/digit_images'
os.makedirs(output_folder, exist_ok=True)

# Metadata of all_digit_objects without the images (the images are not JSON-serializable)
exportable_metadata = {}

# Paths that cv2.imwrite failed to write, so the final message can report them
failed_writes = []

for image_id, columns in all_digit_objects.items():
    exportable_metadata[image_id] = {}

    for col_name, digit_list in columns.items():
        exportable_metadata[image_id][col_name] = []

        for digit_idx, digit_obj in enumerate(digit_list):
            # Save the image as a separate file, numbered from 1 within its column
            image_filename = f"{image_id}_{col_name}_digit{digit_idx + 1}.png"
            image_path = os.path.join(output_folder, image_filename)

            # cv2.imwrite returns False when the file could not be written
            if not cv2.imwrite(image_path, digit_obj.digit_image):
                failed_writes.append(image_path)

            # Prepare metadata without the image itself
            metadata = {
                'cell_number': digit_obj.cell_number,
                'width': digit_obj.width,
                'length': digit_obj.length,
                'x_min': digit_obj.x_min,
                'x_max': digit_obj.x_max,
                'y_min': digit_obj.y_min,
                'y_max': digit_obj.y_max,
                'prediction': digit_obj.prediction,
                'image_path': image_path  # Store the path to the image
            }

            exportable_metadata[image_id][col_name].append(metadata)

# Save the metadata dictionary as a JSON file
with open('Data/It3/all_digit_objects_metadata.json', 'w') as json_file:
    json.dump(exportable_metadata, json_file, indent=4)

# Only claim success when every image was actually written
if failed_writes:
    print(f"Warning: {len(failed_writes)} digit image(s) could not be written, e.g. '{failed_writes[0]}'.")
else:
    print("Dictionary and images saved successfully.")
