In [None]:
# Standard library
from pathlib import Path
from typing import Any, Dict, List, Optional

# Third-party
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

# Project-local
from my_code import (
    BoundingBox,
    IAM_Dataset,
    ImageDimensions,
    MapOrdering,
    WordDetectorNet,
    cluster_aabbs,
    custom_collate_fn,
    decode,
    draw_bboxes_on_image,
    fg_by_cc,
    normalize_image_transform,
)


def run_image_through_network(
        image_grayscale: np.ndarray,
        model_path: Path = Path('best_model.pth'),
        model: Optional[torch.nn.Module] = None,
        threshold: float = 0.5,
        max_num: int = 1000,
    ) -> Dict[str, Any]:
    """Detect word bounding boxes in a single grayscale image.

    The image is resized to ``WordDetectorNet.input_size``, normalized with
    ``normalize_image_transform``, and run through the network. The output
    maps are decoded into axis-aligned boxes, clipped to the network input
    and clustered.

    Args:
        image_grayscale: 2-D grayscale image.
        model_path: Path to a ``WordDetectorNet`` state dict. Only used when
            ``model`` is None.
        model: Optional already-constructed network. Passing one avoids
            re-reading the checkpoint on every call. It is moved to the
            selected device and switched to eval mode.
        threshold: ``thres`` argument forwarded to ``fg_by_cc``.
        max_num: ``max_num`` argument forwarded to ``fg_by_cc``.

    Returns:
        A dict with:
          - ``'aabbs'``: clustered boxes, in the coordinate frame of the
            resized network input (NOT the original image).
          - ``'model_input_image'``: the normalized float32 2-D array fed
            to the network.

    Raises:
        ValueError: If ``image_grayscale`` is not 2-D.
    """
    if image_grayscale.ndim != 2:
        raise ValueError(
            f'expected a 2-D grayscale image, got shape {image_grayscale.shape}'
        )

    # ===========================
    # Configure system / load model
    # ===========================

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if model is None:
        model = WordDetectorNet()
        # torch.load unpickles the checkpoint -- only load trusted files.
        model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # ==============
    # Pre processing
    # ==============

    # NOTE(review): cv2.resize takes dsize as (width, height); confirm that
    # WordDetectorNet.input_size uses the same order.
    image_gray_rescaled = cv2.resize(image_grayscale, WordDetectorNet.input_size)

    # Only works with the current transformation setup (no target passed in).
    image_normalized, _ = normalize_image_transform(image_gray_rescaled, None)
    model_input_image = image_normalized.astype(np.float32)

    # Add batch and channel axes: (H, W) -> (1, 1, H, W).
    input_tensor = torch.from_numpy(model_input_image[None, None, :, :]).to(device)

    # =========
    # Inference
    # =========

    with torch.no_grad():
        output = model(input_tensor, apply_softmax=True)

    # Sanity check: after softmax, these channels must lie in [0, 1].
    softmax_maps = output[:, MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1, :, :]
    assert softmax_maps.min() >= 0.0
    assert softmax_maps.max() <= 1.0

    # Drop the batch axis for decoding.
    output_maps = output[0, :, :, :].to('cpu').numpy()

    # ===============
    # Post processing
    # ===============

    # NOTE(review): only axis 0 is used for the scale, which assumes the
    # input/output ratio is the same along both axes.
    decoded_aabbs = decode(
        output_maps,
        scale=WordDetectorNet.input_size[0] / WordDetectorNet.output_size[0],
        comp_fg=fg_by_cc(thres=threshold, max_num=max_num),
    )

    h, w = model_input_image.shape
    # Bounding boxes must lie inside the network input image.
    image_bounds = BoundingBox(0, 0, w - 1, h - 1)
    aabbs = [aabb.clip(image_bounds) for aabb in decoded_aabbs]
    clustered_aabbs = cluster_aabbs(aabbs)

    return {
        'aabbs': clustered_aabbs,
        'model_input_image': model_input_image,
    }

In [None]:
# Example input image, resolved relative to the notebook's working directory.
example_image_path = Path().joinpath('cvl.jpg')

In [None]:
# Load the example image (OpenCV returns channels in BGR order).
# cv2.imread does not raise on a missing/unreadable file -- it returns None,
# which would only surface later as a confusing cvtColor error.
image = cv2.imread(str(example_image_path))
if image is None:
    raise FileNotFoundError(f'Could not read image file: {example_image_path.resolve()}')

In [None]:
# Collapse the BGR image loaded by cv2.imread into a single grayscale channel.
gray_image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2GRAY)

In [None]:
# Side-by-side view of the loaded image and its grayscale conversion.
# OpenCV stores channels as BGR while matplotlib expects RGB; convert once and reuse.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(image_rgb)
axes[0].set_title(f'Original Image @ {image_rgb.shape}')
axes[0].axis('off')

axes[1].imshow(gray_image, cmap='gray')
axes[1].set_title(f'Grayscale Image @ {gray_image.shape}')
axes[1].axis('off')

fig.suptitle('input vs grayscaled image')

plt.tight_layout()
plt.show()

In [None]:
# Run word detection on the grayscale image with the checkpoint in the working directory.
checkpoint_path = Path('best_model.pth')
result = run_image_through_network(gray_image, model_path=checkpoint_path)

In [None]:
# Overlay the detected word boxes on the exact (normalized) image the network consumed.
vis = draw_bboxes_on_image(
    result['model_input_image'],  # normalized network input
    result['aabbs'],              # clustered boxes in that image's coordinates
)

In [None]:
# Compare the normalized network input with the predicted word boxes drawn on it.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

inp = result['model_input_image']
axes[0].imshow(inp, cmap='gray')
axes[0].set_title(f'NN input @ {inp.shape}')
axes[0].axis('off')

# NOTE(review): assumes draw_bboxes_on_image returns a BGR image when it is
# 3-D (OpenCV convention). A single-channel result is shown as-is, because
# cvtColor(BGR2RGB) raises on 2-D input.
pred = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB) if vis.ndim == 3 else vis
axes[1].imshow(pred, cmap='gray')  # cmap only applies to 2-D data; ignored for RGB
axes[1].set_title(f'NN prediction @ {pred.shape}')
axes[1].axis('off')

fig.suptitle('NN input vs predicted word boxes')

plt.tight_layout()
plt.show()

In [None]:
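# TODO: Move this function into the `my_code` module.
# TODO: Map results back to the original image size: scale up only the bounding boxes (x and y separately, since the resize may change the aspect ratio) and overlay them on the originally loaded image; optionally also overlay them on the rescaled returned image as a sanity check.
# TODO: Put this into a Gradio app.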