In [None]:
# Dependencies
import math
import cv2
import pandas as pd
import numpy as np
from datetime import datetime

from libs.libs import run_sql_query

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

In [2]:
def get_corner_points(rectangles, cells):
    """Return the four outer table corners, read from the detected cells.

    The template ``rectangles`` decide *which* rectangle sits at each outer
    corner of the table: every rectangle contributes its four corner points,
    and for each quadrant around the centroid of all points the point farthest
    from the centroid wins. The entry of ``cells`` whose ``rect_id`` equals the
    winning rectangle's ``id`` then supplies the coordinates of that corner.

    Args:
        rectangles: iterable of dicts with keys ``id``, ``x0``, ``y0``, ``w``
            and ``h``.
        cells: iterable of dicts with key ``rect_id`` plus ``tl_x``/``tl_y``,
            ``tr_x``/``tr_y``, ``br_x``/``br_y`` and ``bl_x``/``bl_y``. When
            several cells share a ``rect_id`` the first one is used.

    Returns:
        ``[[tl_x, tl_y], [tr_x, tr_y], [br_x, br_y], [bl_x, bl_y]]``.

    Raises:
        ValueError: if ``rectangles`` is empty, if a quadrant contains no
            corner point, or if no cell matches the rectangle chosen for a
            corner.
    """
    # Push the 4 corners of each rectangle to points
    points = []
    for r in rectangles:
        id_ = r['id']
        x0, y0, w, h = r['x0'], r['y0'], r['w'], r['h']
        points.extend([
            [id_, x0, y0],                   # top-left corner
            [id_, x0 + w, y0],               # top-right corner
            [id_, x0 + w, y0 + h],           # bottom-right corner
            [id_, x0, y0 + h]                # bottom-left corner
        ])

    if not points:
        raise ValueError("get_corner_points() needs at least one rectangle")

    # Get the center of all points
    cx = sum(point[1] for point in points) / len(points)
    cy = sum(point[2] for point in points) / len(points)

    # For every quadrant keep (rectangle id, distance) of the point that is
    # farthest from the center; None means "no point seen in this quadrant".
    quadrants = ('tl', 'tr', 'br', 'bl')
    farthest = dict.fromkeys(quadrants)

    for id_, x, y in points:

        # Check in which quadrant the point is. Points lying exactly on one
        # of the center axes belong to no quadrant and are skipped.
        if x < cx and y < cy:
            quadrant = 'tl'
        elif x > cx and y < cy:
            quadrant = 'tr'
        elif x > cx and y > cy:
            quadrant = 'br'
        elif x < cx and y > cy:
            quadrant = 'bl'
        else:
            continue

        dist = math.sqrt((cx - x) ** 2 + (cy - y) ** 2)

        # Strict comparison: on a tie the first point encountered is kept
        best = farthest[quadrant]
        if best is None or dist > best[1]:
            farthest[quadrant] = (id_, dist)

    # Get the corresponding cells and read the matching corner from each
    corners = []
    for quadrant in quadrants:
        best = farthest[quadrant]
        if best is None:
            raise ValueError(f"no rectangle corner found in the '{quadrant}' quadrant")

        rect_id = best[0]
        cell = next((c for c in cells if c['rect_id'] == rect_id), None)
        if cell is None:
            raise ValueError(f"no cell found for rect_id {rect_id!r} ('{quadrant}' corner)")

        corners.append([cell[f'{quadrant}_x'], cell[f'{quadrant}_y']])

    # Return corners
    return corners

In [3]:
def reproject(src_points, dst_points, rectangles):
    """Project axis-aligned template rectangles through a homography.

    A homography is estimated from the four ``src_points`` to the four
    ``dst_points`` and then applied to the four corners of every rectangle.

    Args:
        src_points: four ``[x, y]`` pairs in the source coordinate system.
        dst_points: four ``[x, y]`` pairs in the destination coordinate
            system, in the same order as ``src_points``.
        rectangles: iterable of dicts with keys ``x0``, ``y0``, ``w``, ``h``.

    Returns:
        A list with one dict per rectangle (same order as ``rectangles``)
        holding the projected corners under the keys ``tl_x``, ``tl_y``,
        ``tr_x``, ``tr_y``, ``br_x``, ``br_y``, ``bl_x``, ``bl_y``. The values
        are numpy ``float32`` scalars.

    Raises:
        ValueError: if OpenCV could not estimate a homography from the given
            point pairs (e.g. degenerate / collinear points).
    """
    # Create source and destination points matrices
    srcPoints = np.array(src_points, dtype=np.float32).reshape(4, 1, 2)
    dstPoints = np.array(dst_points, dtype=np.float32).reshape(4, 1, 2)

    # Find homography matrix (findHomography returns (matrix, inlier mask))
    homographyMatrix = cv2.findHomography(srcPoints, dstPoints, cv2.RANSAC, 5.0)[0]

    # Without this guard a failed estimation only shows up as an opaque
    # OpenCV assertion inside perspectiveTransform.
    if homographyMatrix is None:
        raise ValueError(
            f"could not estimate a homography from {src_points!r} to {dst_points!r}"
        )

    # Project the points
    rectangles_projected = []
    for rectangle in rectangles:

        # unpack into left/top/right/bottom edges
        left = rectangle['x0']
        top = rectangle['y0']
        right = left + rectangle['w']
        bottom = top + rectangle['h']

        # Convert to array, corner order: tl, tr, br, bl
        points = np.array(
            [[left, top], [right, top], [right, bottom], [left, bottom]],
            dtype=np.float32,
        ).reshape(4, 1, 2)

        # Reproject
        transformedCorners = cv2.perspectiveTransform(points, homographyMatrix)
        transformedCorners = transformedCorners.reshape(4, 2)

        # Append reprojected rectangle to results
        rectangles_projected.append({
            "tl_x": transformedCorners[0][0],
            "tl_y": transformedCorners[0][1],
            "tr_x": transformedCorners[1][0],
            "tr_y": transformedCorners[1][1],
            "br_x": transformedCorners[2][0],
            "br_y": transformedCorners[2][1],
            "bl_x": transformedCorners[3][0],
            "bl_y": transformedCorners[3][1]
        })

    return rectangles_projected

### Images

In [4]:
# Load every image that is approved (or not yet reviewed) and whose status is
# not 0, newest first, together with the cells attached to it aggregated into
# one JSON array per image.
query = """
SELECT 
    image.id AS image_id,
    DATE(image.created_at) AS created_at,
    image.table_template_id AS table_template_id,
    image.rotation AS rotation,
    JSON_AGG(JSON_BUILD_OBJECT(
        'rect_id', rect_id,
        'tl_x', tl_x,
        'tl_y', tl_y,
        'tr_x', tr_x,
        'tr_y', tr_y,
        'bl_x', bl_x,
        'bl_y', bl_y,
        'br_x', br_x,
        'br_y', br_y
    )) AS cells
FROM 
    image
LEFT JOIN
    cell
ON
    cell.image_id = image.id
WHERE
    (image.approved IS NULL OR image.approved = true) AND image.status != 0
GROUP BY 
    image.id, image.created_at, image.table_template_id, image.rotation
ORDER BY
    image.created_at DESC
"""

# Turn each positional result row into a dict keyed by column name; the order
# of the names mirrors the SELECT list above.
image_columns = ('image_id', 'created_at', 'table_template_id', 'rotation', 'cells')
images = [
    {column: row[position] for position, column in enumerate(image_columns)}
    for row in run_sql_query(query)
]

### Table Templates

In [5]:
# Fetch the id and the rectangles of every table template
query = """
SELECT 
    id,
    rectangles
FROM 
    table_template
"""

table_templates = run_sql_query(query)

# Index the rectangles by template id (first column -> second column)
table_templates_dict = {row[0]: row[1] for row in table_templates}

### Set corners

In [6]:
# Attach the outer table corners and an approximate table size to every image
for image in images:

    # rectangles of the template currently assigned to the image
    template_rectangles = table_templates_dict[image['table_template_id']]

    # corners come back ordered as tl, tr, br, bl, each an [x, y] pair
    corners = get_corner_points(template_rectangles, image['cells'])
    (tl_x, tl_y), (tr_x, tr_y), (br_x, br_y), (bl_x, bl_y) = corners

    # width  = mean of the top and bottom edge lengths (x extent)
    # height = mean of the left and right edge lengths (y extent)
    image['corners'] = corners
    image['width'] = ((tr_x - tl_x) + (br_x - bl_x)) / 2.0
    image['height'] = ((bl_y - tl_y) + (br_y - tr_y)) / 2.0

### Train a Random Forest Classifier

In [7]:
# Parameters
# Bounds (as date objects) of the creation-date window: the first entry is the
# lower bound, the second the upper bound.
MIN_CREATED_AT, MAX_CREATED_AT = (
    datetime.strptime(day, '%Y-%m-%d').date()
    for day in ('2023-08-01', '2024-08-13')
)

In [None]:
# Images created inside [MIN_CREATED_AT, MAX_CREATED_AT] (both inclusive) are
# used for training; images created after MAX_CREATED_AT are the ones to be
# evaluated. Images older than MIN_CREATED_AT end up in neither list.
images_training = [
    im for im in images
    if MIN_CREATED_AT <= im['created_at'] <= MAX_CREATED_AT
]
images_evaluation = [im for im in images if im['created_at'] > MAX_CREATED_AT]

# format: one row per image; the evaluation frame additionally keeps the image id
feature_keys = ('width', 'height', 'table_template_id')
data_training = pd.DataFrame(
    [{key: im[key] for key in feature_keys} for im in images_training]
)
data_evaluation = pd.DataFrame(
    [{key: im[key] for key in ('image_id',) + feature_keys} for im in images_evaluation]
)

# log
print(f'Number of images training: {len(images_training)}')
print(f'Number of images evaluation: {len(images_evaluation)}')

In [None]:
# Split the training window into a training and a held-out test set
# (90% train, 10% test, as given by test_size=0.1)
X = data_training[['width', 'height']]
y = data_training['table_template_id']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

# Train the Random Forest classifier: table width/height -> table template id
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Predict using the trained model
y_pred = clf.predict(X_test)

# Calculate and print the accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy*100:.2f}%")

# NOTE: the mean squared error that used to be reported here was dropped on
# purpose. The target is a categorical template id, so the numeric distance
# between two ids carries no meaning and the figure was misleading.

In [10]:
# Predict on evaluation
X_eval = data_evaluation[['width', 'height']]

# Predict using the trained model
y_eval = clf.predict(X_eval)

# Set
data_evaluation['table_template_id_pred'] = y_eval

# Set in original dataset.
# Build an image_id -> prediction lookup once instead of rescanning
# images_evaluation for every prediction (which was quadratic). Should an
# image id occur more than once, the last prediction wins, as before.
predictions_by_image_id = {
    prediction['image_id']: prediction['table_template_id_pred']
    for prediction in data_evaluation.to_dict(orient='records')
}

for image in images_evaluation:
    image_id = image['image_id']
    if image_id in predictions_by_image_id:
        image['table_template_id_pred'] = predictions_by_image_id[image_id]

### Go through and project

In [11]:
# Map the rectangles of the predicted template onto each evaluation image
for image in images_evaluation:

    # rectangles of the template predicted for this image
    rectangles = table_templates_dict[image['table_template_id_pred']]

    # bounding box of the table template; the infinities are the starting
    # values, so an empty template yields (inf, -inf, inf, -inf)
    min_x = min([float('inf'), *(r['x0'] for r in rectangles)])
    max_x = max([float('-inf'), *(r['x0'] + r['w'] for r in rectangles)])
    min_y = min([float('inf'), *(r['y0'] for r in rectangles)])
    max_y = max([float('-inf'), *(r['y0'] + r['h'] for r in rectangles)])

    # source points are the bounding-box corners, ordered tl, tr, br, bl
    src_points = [
        [min_x, min_y],
        [max_x, min_y],
        [max_x, max_y],
        [min_x, max_y],
    ]

    # reproject from template space onto the corners stored on the image
    rectangles_projected = reproject(src_points, image['corners'], rectangles)

    # carry the template metadata over to each projected rectangle; both
    # lists are in the same order
    for projected, template_rect in zip(rectangles_projected, rectangles):
        for key in ('id', 'data_type', 'opts'):
            projected[key] = template_rect[key]

    # set
    image['rectangles'] = rectangles_projected

### Preview

In [None]:
# Folder that holds one sub-folder per image id: <IMAGES_DIR>/<id>/<id>.jpg
# NOTE(review): machine-specific absolute path; adjust here when running elsewhere.
IMAGES_DIR = "/Users/admin/3-CAI/images"

# Visual sanity check: overlay the stored corners and the projected
# rectangles on the image. Only the first evaluation image is shown.
for image in images_evaluation:

    # unpack
    image_id = image['image_id']
    rotation = image['rotation']
    corners = image['corners']
    rectangles = image['rectangles']

    # Load image using PIL; the context manager closes the file handle once
    # the pixel data has been copied into the array
    image_path = f"{IMAGES_DIR}/{image_id}/{image_id}.jpg"
    with Image.open(image_path) as img:
        img_array = np.array(img)

    # rotate by the stored rotation in steps of 90 degrees; rotations below
    # 90 (including negative values) leave the image untouched, as before
    nbr_rotations = max(int(int(rotation) / 90.0), 0)
    img_array = np.rot90(img_array, k=nbr_rotations)

    # Create a plot
    fig, ax = plt.subplots()
    ax.imshow(img_array)

    # Draw points on the image
    for x, y in corners:
        ax.plot(x, y, 'ro')  # 'ro' means red color, circle marker

    # Draw rectangles on the image
    for r in rectangles:

        # closed outline tl -> tr -> br -> bl -> tl
        outline_x = [r['tl_x'], r['tr_x'], r['br_x'], r['bl_x'], r['tl_x']]
        outline_y = [r['tl_y'], r['tr_y'], r['br_y'], r['bl_y'], r['tl_y']]

        # draw (make line thin)
        ax.plot(outline_x, outline_y, 'r-', linewidth=0.5)

    plt.show()

    break

### Push

In [14]:
# Compose one INSERT statement per projected rectangle and store the list on
# the image under 'queries'. Nothing is executed here.
#
# The statements are assembled as text, so every value placed inside a quoted
# SQL literal has its single quotes doubled; otherwise a value containing "'"
# would break the statement or inject SQL.
# NOTE(review): bound parameters would be preferable if run_sql_query supports
# them. rect_id is interpolated unquoted, exactly as before.
# NOTE(review): opts is rendered with str(); if it is a dict/list and the
# column expects JSON, it probably needs json.dumps instead. Confirm.
for image in images_evaluation:

    # unpack
    image_id = str(image['image_id']).replace("'", "''")
    rectangles = image['rectangles']
    image['queries'] = []

    # go through rectangles
    for r in rectangles:

        # unpack
        rect_id = r['id']
        opts = str(r['opts']).replace("'", "''")
        data_type = str(r['data_type']).replace("'", "''")

        # Coordinates are rounded to the nearest integer. Going through
        # float()/int() guarantees a plain integer literal regardless of how
        # round() treats the numpy scalars coming out of the projection.
        tl_x, tl_y, tr_x, tr_y, br_x, br_y, bl_x, bl_y = (
            int(round(float(r[key])))
            for key in ('tl_x', 'tl_y', 'tr_x', 'tr_y', 'br_x', 'br_y', 'bl_x', 'bl_y')
        )

        # compose
        sql_query = f"INSERT INTO cell (image_id, rect_id, opts, data_type, tl_x, tl_y, tr_x, tr_y, br_x, br_y, bl_x, bl_y) VALUES ('{image_id}', {rect_id}, '{opts}', '{data_type}', {tl_x}, {tl_y}, {tr_x}, {tr_y}, {br_x}, {br_y}, {bl_x}, {bl_y})"

        # append
        image['queries'].append(sql_query)

In [15]:
# Write the results back: replace the cells of every evaluation image with the
# projected ones, then store the predicted template id and set status to 5.
#
# The image row is updated only after all of its cells were inserted, so an
# error while inserting never leaves an image flagged with status 5 but
# without cells.
# NOTE(review): the statements still run one by one; if run_sql_query offers
# transactions, DELETE + INSERTs + UPDATEs should share one.
for image in images_evaluation:

    # unpack
    queries = image['queries']

    # double single quotes so the id cannot break out of the SQL literal
    image_id = str(image['image_id']).replace("'", "''")

    # compose
    sql_query_delete = f"DELETE FROM cell WHERE image_id = '{image_id}'"
    sql_query_update_table_id = f"UPDATE image SET table_template_id = {image['table_template_id_pred']} WHERE id = '{image_id}'"
    sql_query_update_status = f"UPDATE image SET status = 5 WHERE id = '{image_id}'"

    # remove the old cells, then insert the projected ones
    run_sql_query(sql_query_delete)
    for query in queries:
        run_sql_query(query)

    # finally mark the image itself
    run_sql_query(sql_query_update_table_id)
    run_sql_query(sql_query_update_status)