# Setup

In [None]:
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
from jupyter_bbox_widget import BBoxWidget
from plotly.express import imshow as imshow_interactive

In [None]:
import sys
import pathlib

# The notebook is run from a directory one level below the repository root
# (datasets are referenced as "../datasets/"), so the repository root is the
# parent of the current working directory.
REPO_DIR = pathlib.Path.cwd().resolve().parent
SRC_DIR = str(REPO_DIR / "src")
# Make the local `scanplot` package importable; the membership check keeps
# re-executions of this cell from piling up duplicate sys.path entries.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [None]:
from scanplot.io import load_image, dump_coords_csv
from scanplot.plotting import draw_image, draw_ROI
from scanplot.view import (
    CoordinatesMapper, DetectorWidget, MarkerSelectorBBoxWidget, ROISelectorBBoxWidget
)

from scanplot.core.process_template import extract_markers_from_image

from scanplot.core import (
    template_match,
    replace_black_pixels,
    normalize_map,
    image_tresholding,
    get_template_mask,
    center_object_on_template_image,
    reconstruct_template_mask,
    generalized_hough_transform,
    CoordinatesConverter,
    bboxes_to_roi,
    _apply_roi,
)

In [None]:
from scanplot.core import Plot

# Algorithm Pipeline

## 1. Load plot image

Input data examples can be found in `<REPOSITORY>/datasets/`.

In [None]:
# Example datasets live in <REPOSITORY>/datasets/, addressed relative to the cwd.
data_path = pathlib.Path("../datasets/")
plot_number = 37
plot_image_filepath = data_path.joinpath("plot_images", f"plot{plot_number}.png")

# Load the scatter-plot image and render it.
plot = Plot(plot_image_filepath)
plot.draw()

## 2. Select markers on image


Specify how many marker types are on the image

In [None]:
# Number of distinct marker types present on the plot image.
plot.set_markers_number(2)

In [None]:
# Bounding-box widget for selecting an example of each marker type on the
# plot image; it is told how many marker types to expect.
marker_selector_widget = MarkerSelectorBBoxWidget(
    image_data=plot.data,
    markers_number=plot.markers_number,
)

In [None]:
# Render the marker selector so the bounding boxes can be drawn.
display(marker_selector_widget)

Validate chosen bounding boxes and extract marker images

In [None]:
# Check the drawn bounding boxes before they are used.
marker_selector_widget.validate_bboxes()

# Extract the marker images from the plot using the selected boxes.
plot.extract_markers(marker_bboxes=marker_selector_widget.bboxes)

In [None]:
## draw extracted markers

# for marker in plot.markers:
#     draw_image(marker)
#     plt.show()

## 3. (optional step) Select region of interest

For each marker type you can specify its own region of interest.
The default ROI is the whole image.

In [None]:
# Bounding-box widget for optionally selecting a region of interest for each
# marker type.
roi_widget = ROISelectorBBoxWidget(
    image_data=plot.data, 
    markers_number=plot.markers_number,
)

In [None]:
# Render the ROI selector.
display(roi_widget)

In [None]:
# roi_widget.bboxes

In [None]:
# Apply the regions of interest selected in the widget to the plot.
plot.apply_roi(roi_bboxes=roi_widget.bboxes)

In [None]:
# TODO

# ## draw obtained ROI

# plt.subplot(1, 2, 1)
# draw_image(plot.data)
# draw_ROI(_pl)

# plt.subplot(1, 2, 2)
# draw_image(plot.data)
# draw_ROI(roi_list[1])

## 4. Run matching algorithms

In [None]:
# Run the matching algorithms on the `Plot` object created in step 1.
# Returns the correlation maps consumed by the detector below.
correlation_maps = plot.run_matching()

Note: image preprocessing happens inside `run_matching()`.

## 5. Select algorithm parameters

In [None]:
## input:
# 1) list of correlation maps
# 2) list of preprocessed images
# 3) list of preprocessed marker images (only for shapes)

# Build the detector from the `Plot` object created in step 1 and the
# correlation maps returned by `run_matching()`.
detector = DetectorWidget(
    source_image=plot.data,
    correlation_maps_list=correlation_maps,
    markers_list=plot.markers,
)

In [None]:
# Appearance options forwarded to the detector's main widget.
widget_settings = dict(
    fig_size=9,
    marker_size=70,
    marker_color="yellow",
    marker_type="*",
)
detector_widget = detector.main_widget(**widget_settings)

In [None]:
# Render the detector widget for parameter tuning.
display(detector_widget)

In [None]:
all_markers_detections = detector.get_detections()
marker1_detections = detector.get_detections_for_marker(marker="marker1")

# Points coordinates in pixels (!) for the first marker type.
x = marker1_detections.x
y = marker1_detections.y
plt.scatter(x, y)

## 6. Convert obtained coordinates from pixel to real values

In [None]:
# Mapper between pixel and factual coordinates, built on the image held by
# the `Plot` object from step 1 (`source_plot_image` does not exist yet at
# this point of the notebook).
mapper = CoordinatesMapper(plot.data)

fig_size = 10
mapper_widget = mapper.interactive_widget(fig_size=fig_size)

In [None]:
# Render the interactive mapper widget.
display(mapper_widget)

Convert pixel coordinates to factual coordinates

In [None]:
# Pixel -> factual converter, configured with the parameters collected by
# the mapper widget.
converter = CoordinatesConverter()
converter.import_parameters_from_mapper(mapper)

In [None]:
# Pixel coordinates of the first marker type. x and y must come from the
# same detections so that each (x, y) pair describes one point.
x_px = marker1_detections.x
y_px = marker1_detections.y

x_factual, y_factual = converter.from_pixel(x_pixel=x_px, y_pixel=y_px)

In [None]:
# Preview the converted points in factual coordinates.
plt.scatter(x_factual, y_factual)

In [None]:
# Input locations for the per-marker pipeline below. The dataset root is
# defined here so this cell does not depend on an undefined name.
DATA_PATH = pathlib.Path("../datasets/")
PLOT_NUMBER = 37
MARKER_NUMBER = 2

PLOT_PATH = DATA_PATH / "plot_images" / f"plot{PLOT_NUMBER}.png"
TEMPLATE_PATH = DATA_PATH / "marker_images" / f"plot{PLOT_NUMBER}_marker{MARKER_NUMBER}.png"

In [None]:
source_plot_image = load_image(PLOT_PATH)
source_template_image = load_image(TEMPLATE_PATH)

# Working copies; the source_* arrays are kept as loaded for later cells.
plot_image, template_image = np.copy(source_plot_image), np.copy(source_template_image)

In [None]:
# Show the source plot (left) and the chosen marker (right) side by side.
plt.subplot(1, 2, 1)
draw_image(source_plot_image)
plt.title("Source scatter plot image")

plt.subplot(1, 2, 2)
draw_image(source_template_image)
plt.title("Chosen marker")

## 2. (optional step) Select region of interest

The default ROI is the whole image

In [None]:
# Encode the plot image as PNG bytes for the bounding-box widget.
# cv.imencode returns (success_flag, buffer); fail loudly rather than hand
# an invalid buffer to the widget.
encoded_ok, png_buffer = cv.imencode(".png", source_plot_image)
if not encoded_ok:
    raise RuntimeError("Failed to encode the plot image as PNG")

# Single-class widget for drawing the region of interest.
roi_widget = BBoxWidget(
    hide_buttons=True,
    classes=["Region of interest"],
    image_bytes=png_buffer.tobytes(),
    colors=["green"],
)

In [None]:
# Render the ROI widget.
display(roi_widget)

In [None]:
# Convert the drawn boxes into an ROI and restrict the working image to it.
roi = bboxes_to_roi(source_plot_image, roi_widget.bboxes)
# The ROI helper is imported from scanplot.core as `_apply_roi`.
plot_image = _apply_roi(source_plot_image, roi)

# Show the ROI over the full source image.
draw_image(source_plot_image)
draw_ROI(roi)

## 3. Run matching algorithms

3.1. Preprocess plot image and template image

In [None]:
# Mask of the marker within the template image.
template_mask_initial = get_template_mask(source_template_image)

# Presumably re-centers the marker within the template image and returns the
# matching mask (semantics defined in scanplot.core — confirm).
template_image, template_mask = center_object_on_template_image(
    source_template_image, template_mask_initial
)

# NOTE(review): presumably replaces black pixels with the value 10 in both
# images before matching — confirm exact semantics in scanplot.core.
plot_image = replace_black_pixels(plot_image, value=10)
template_image = replace_black_pixels(template_image, value=10)

# additional_template_mask = reconstruct_template_mask(template_mask)

3.2. Run the template matching algorithm and compute the correlation map

In [None]:
# Template matching of the masked marker over the plot image; only the first
# returned value (the correlation map) is kept.
correlation_map, _ = template_match(
    plot_image, template_image, template_mask, norm_result=True
)

# correlation_map_additional, _ = template_match(
#     plot_image, template_image, additional_template_mask, norm_result=True
# )

3.3. Run Hough transform algorithm and compute accumulator array

In [None]:
accumulator = generalized_hough_transform(
    plot_image,
    template_image,
    norm_result=True,
    crop_result=True,
)

# The accumulator is added element-wise to the correlation map next,
# so the two arrays must have the same shape.
assert accumulator.shape == correlation_map.shape

3.4. Combine correlation map obtained by template matching algorithm and accumulator array from Hough Transform

In [None]:
# Weighted sum of the template-matching map and the Hough accumulator
# (accumulator weight 0.6), rescaled with normalize_map.
correlation_map_with_hough = normalize_map(correlation_map + 0.6 * accumulator)

# correlation_map_combined = normalize_map(correlation_map + 0.7 * correlation_map_additional)

## 4. Select algorithm parameters

The algorithm has 2 parameters:
- Points Number
- Points Density

Learn more about parameters selection in [documentation](https://github.com/adusachev/scanplot/blob/master/docs/user_manual.md#detector).

In [None]:
# Detector widget built from the source image, the preprocessed marker
# template and the combined (template matching + Hough) correlation map.
detector = DetectorWidget(
    source_image=source_plot_image,
    template=template_image,
    correlation_map=correlation_map_with_hough
)

In [None]:
# Appearance options forwarded to the detector's main widget.
widget_settings = dict(
    fig_size=9,
    marker_size=70,
    marker_color="yellow",
    marker_type="*",
)
detector_widget = detector.main_widget(**widget_settings)

In [None]:
# Render the detector widget for parameter tuning.
display(detector_widget)

In [None]:
detected_points_px = detector.get_detections()

# Points coordinates in pixels (!): column 0 holds x, column 1 holds y.
x, y = detected_points_px[:, 0], detected_points_px[:, 1]
plt.scatter(x, y)

## 5. Convert obtained coordinates from pixel to real values

Map pixel coordinates to factual coordinates

In [None]:
# Mapper between pixel and factual coordinates for the source plot image.
mapper = CoordinatesMapper(source_plot_image)

In [None]:
# Figure size passed to the interactive mapper widget.
fig_size = 10

mapper_widget = mapper.interactive_widget(fig_size=fig_size)
display(mapper_widget)

Convert pixel coordinates to factual coordinates

In [None]:
# Pixel -> factual converter, configured with the parameters collected by
# the mapper widget.
converter = CoordinatesConverter()
converter.import_parameters_from_mapper(mapper)

In [None]:
# Split the pixel detections into x / y columns and convert them to
# factual (axis) coordinates.
x_px, y_px = detected_points_px[:, 0], detected_points_px[:, 1]
x_factual, y_factual = converter.from_pixel(x_pixel=x_px, y_pixel=y_px)

In [None]:
# Preview the converted points in factual coordinates.
plt.scatter(x_factual, y_factual)

Save the obtained coordinates to a CSV file

In [None]:
# Write the factual coordinates to a CSV file named after the plot and marker
# numbers; the path has no directory part, so it lands in the current
# working directory.
dump_coords_csv(
    x=x_factual,
    y=y_factual,
    savepath=f"detections_plot{PLOT_NUMBER}_marker{MARKER_NUMBER}.csv"
)