In [None]:
import cv2
import numpy as np
import os
from abc import ABC, abstractmethod
import pickle
from skimage.feature import match_template
import matplotlib.pyplot as plt
import pandas as pd
import imageio as iio

In [None]:
# naming for positions jpg: YYMMDD_Positions_CamID.jpg
# naming for template jpg: YYMMDD_CamID_ObjectToFind_levelN_template.jpg (N = 1, 2, ...; ObjectToFind must not contain "_")

In [None]:
# TODO: push the feature/template_matching branch

## Class for single-camera template matching using OpenCV or scikit-image (skimage)

In [None]:
class TemplateMatching(ABC):
    """Hierarchical (coarse-to-fine) template matching of marker positions for one camera.

    For every marker ("ObjectToFind") that has templates for this camera in
    ``template_directory``, the level-1 template is matched against the whole
    positions image. Each further level (2, 3, ...) is then matched only inside
    the region found by the previous level. The centre of the deepest matched
    template, shifted back into full-image coordinates, is stored as the marker
    position. Matching uses OpenCV (``cv2.matchTemplate`` with
    ``TM_CCORR_NORMED``) or scikit-image (``match_template``).
    """

    @property
    def template_directory(self) -> str:
        # specifies the directory for originally used templates
        return r"/Users/kobel/Documents/Medizin/Doktorarbeit/Coding/templates_obj/"

    @property
    def template_naming(self) -> str:
        # substring every template filename must contain
        return 'template.jpg'

    @property
    def template_matching_threshold(self) -> float:
        # a level counts as found only if its matching score is strictly above this value
        return 0.95

    @property
    def matching_naming(self) -> str:
        # substring identifying the positions image to be matched
        return '_Positions'

    def __init__(self, directory_positions_jpg: str, cam_id: str, visualize_matching: bool = False, use_open_cv: bool = True):
        """
        Parameters
        ----------
        directory_positions_jpg: str
            Directory containing the positions image of this camera.
        cam_id: str
            Camera identifier; must occur in the positions and template filenames.
        visualize_matching: bool
            If True, keep every level's match and plot them at the end of ``run``.
        use_open_cv: bool
            If True match with OpenCV, otherwise with scikit-image.
        """
        self.directory_positions_jpg = directory_positions_jpg
        self.visualize_matching = visualize_matching
        if self.visualize_matching:
            # marker_id -> list of (image, rectangle, accuracy), one entry per matched level
            self.visualization = {}
        self.cam_id = cam_id
        self.use_open_cv = use_open_cv
        # markers for which at least one level could not be matched
        self.markers_to_remove = []
        # marker_id -> {'x': [...], 'y': [...], 'likelihood': [...]}; created up front so that
        # exporting does not fail with AttributeError when no marker could be matched
        self.manual_test_position_marker_coords_pred = {}

    def export_as_DLC_h5(self) -> None:
        """Write all stored marker coordinates as a DLC-style DataFrame to an .h5 file.

        The file is written to the current working directory as
        ``f'{cam_id}_templatematched_test_position_marker_fake.h5'`` under key "df".
        """
        df = self._construct_dlc_output_style_df_from_manual_marker_coords()
        if self.visualize_matching:
            print(df)
        output_filepath = f'{self.cam_id}_templatematched_test_position_marker_fake.h5'
        df.to_hdf(output_filepath, "df")

    def run(self) -> None:
        """Match all templates of this camera and store the resulting marker coordinates.

        Markers whose matching fails at any level are reported, skipped and
        removed from ``self.paths_template``. Their positions have to be added
        manually via ``_add_templatematched_test_position_marker``.
        """
        print(f'Now analyzing {self.cam_id}!')
        self.path_positions = self._get_positions_path(path=self.directory_positions_jpg, naming_flag=self.matching_naming)
        self.paths_template = self._get_template_paths(path=self.template_directory, naming_flag=self.template_naming)

        original_image = self._read_positions_image()

        for marker_id in self.paths_template.keys():
            if self.visualize_matching:
                self.visualization[marker_id] = []

            image, template, location, offset = self._match_levels(original_image=original_image, marker_id=marker_id)

            if marker_id not in self.markers_to_remove:
                coordinates = self._get_mean_coordinates(image, template, location)
                transposed_coordinates = self._transpose_coordinates(coordinates, offset)
                self._add_templatematched_test_position_marker(marker_id=marker_id, x_or_column_idx=transposed_coordinates[0], y_or_row_idx=transposed_coordinates[1], likelihood=0.9999, overwrite=False)

        for marker_id in self.markers_to_remove:
            self.paths_template.pop(marker_id, None)

        if self.visualize_matching:
            self._visualize_predictions()

    def _match_levels(self, original_image: np.ndarray, marker_id: str) -> tuple[np.ndarray, np.ndarray, tuple[int, int], tuple[int, int]]:
        """Run the coarse-to-fine matching of one marker over levels 1, 2, 3, ...

        Levels are tried in order until the next level has no template.
        Matching stops early as soon as a level fails, because the placeholder
        location of a failed level does not describe a real region.

        Returns
        -------
        (image, template, location, offset): the region searched at the last
        attempted level, the template and top-left (x, y) location matched in
        it, and the (x, y) offset of that region inside ``original_image``.

        Raises
        ------
        KeyError if the marker has no level-1 template.
        """
        levels = self.paths_template[marker_id]
        if '1' not in levels:
            raise KeyError(f'No level 1 template found for {marker_id}!')

        image = original_image.copy()
        offset = (0, 0)
        template, location = None, None
        level = 1
        while str(level) in levels:
            if level > 1:
                # restrict the search to the region matched by the previous level
                offset = (offset[0] + location[0], offset[1] + location[1])
                image = image[location[1]:location[1] + template.shape[0], location[0]:location[0] + template.shape[1]]
            template, location = self._iterate_template_matching(camera=self.cam_id, image=image, level=str(level), marker_id=marker_id)
            if marker_id in self.markers_to_remove:
                break
            level += 1
        return image, template, location, offset

    def _add_templatematched_test_position_marker(self, marker_id: str, x_or_column_idx: int, y_or_row_idx: int, likelihood: float, overwrite: bool=False) -> None:
        """Store the (x, y, likelihood) of one marker.

        Raises ValueError if the marker already has coordinates and ``overwrite`` is falsy.
        """
        if not hasattr(self, 'manual_test_position_marker_coords_pred'):
            self.manual_test_position_marker_coords_pred = {}
        if marker_id in self.manual_test_position_marker_coords_pred and not overwrite:
            raise ValueError('There are already coordinates for the marker you '
                             f'tried to add: "{marker_id}: {self.manual_test_position_marker_coords_pred[marker_id]}'
                             '". If you would like to overwrite these coordinates, please pass '
                             '"overwrite = True" as additional argument to this method!')
        self.manual_test_position_marker_coords_pred[marker_id] = {'x': [x_or_column_idx], 'y': [y_or_row_idx], 'likelihood': [likelihood]}

    def _construct_dlc_output_style_df_from_manual_marker_coords(self) -> pd.DataFrame:
        """Build a one-row DataFrame with (scorer, bodyparts, coords) MultiIndex columns."""
        multi_index = self._get_multi_index()
        df = pd.DataFrame(data={}, columns=multi_index)
        for scorer, marker_id, key in df.columns:
            df[(scorer, marker_id, key)] = self.manual_test_position_marker_coords_pred[marker_id][key]
        return df

    def _create_matched_template_plot(self, image: np.ndarray, template: np.ndarray, location: tuple[int, int], accuracy: float) -> tuple[np.ndarray, tuple[list, list], float]:
        """Return (image, rectangle, accuracy) for plotting.

        ``rectangle`` is a pair of x- and y-coordinate lists that trace the
        matched template's outline as a closed polygon.
        """
        height_template, width_template = template.shape[0], template.shape[1]
        x_left, y_top = location[0], location[1]
        x_right, y_bottom = x_left + width_template, y_top + height_template
        rectangle = [x_left, x_right, x_right, x_left, x_left], [y_top, y_top, y_bottom, y_bottom, y_top]
        return (image, rectangle, accuracy)

    def _evaluate_template_matching(self, accuracy: float) -> bool:
        """True if ``accuracy`` is not None and strictly above ``template_matching_threshold``."""
        return bool(accuracy is not None and accuracy > self.template_matching_threshold)

    def _get_mean_coordinates(self, image: np.ndarray, template: np.ndarray, location: tuple[int, int]) -> tuple[int, int]:
        """Centre (x, y) of ``template`` placed with its top-left corner at ``location``; ``image`` is unused."""
        height_template, width_template = template.shape[0:2]
        center_coordinates = (location[0] + width_template//2, location[1] + height_template//2)
        return center_coordinates

    def _get_multi_index(self) -> pd.MultiIndex:
        """Columns (scorer, marker_id, x|y|likelihood) for every stored marker."""
        multi_index_column_names = [[], [], []]
        for marker_id in self.manual_test_position_marker_coords_pred.keys():
            for column_name in ("x", "y", "likelihood"):
                multi_index_column_names[0].append("templatematched_marker_positions")
                multi_index_column_names[1].append(marker_id)
                multi_index_column_names[2].append(column_name)
        return pd.MultiIndex.from_arrays(multi_index_column_names, names=('scorer', 'bodyparts', 'coords'))

    def _get_positions_path(self, path: str, naming_flag: str) -> str:
        """Return the path of the file in ``path`` whose name contains ``naming_flag`` and the camera id.

        If several files match, the last one in sorted order is used, so the
        choice is deterministic. Raises FileNotFoundError if nothing matches.
        """
        matches = sorted(elem for elem in os.listdir(path) if naming_flag in elem and self.cam_id in elem)
        if not matches:
            raise FileNotFoundError(f'Found no suiting _Positions file for {self.cam_id}!')
        return os.path.join(path, matches[-1])

    def _get_template_paths(self, path: str, naming_flag: str) -> dict:
        """Collect the template files of this camera, grouped by object and level.

        Filenames are expected as ``..._ObjectToFind_levelN_<naming_flag>``.
        Returns ``{ObjectToFind: {'N': filepath}}``.

        Raises FileNotFoundError if no template matches, and ValueError if a
        filename lacks the ``levelN`` part.
        """
        all_filenames = sorted(elem for elem in os.listdir(path) if naming_flag in elem and self.cam_id in elem)
        if not all_filenames:
            raise FileNotFoundError(f'Found no suiting templates for {self.cam_id}')

        files = {}
        for filename in all_filenames:
            split_name = filename.split("_")
            # group by the exact object token; a substring test would mix e.g. "A" and "BA"
            objecttofind, level_token = split_name[-3], split_name[-2]
            if not level_token.startswith('level'):
                raise ValueError(f'Template "{filename}" does not follow the naming ..._ObjectToFind_levelN_{naming_flag}')
            level = level_token[len('level'):]
            files.setdefault(objecttofind, {})[level] = os.path.join(path, filename)
        return files

    def _iterate_template_matching(self, camera: str, image: np.ndarray, level: str, marker_id: str) -> tuple[np.ndarray, tuple[int, int]]:
        """Match the level-``level`` template of ``marker_id`` inside ``image``.

        ``camera`` is currently unused. On success returns the template and the
        top-left (x, y) location of the best match, and (if visualizing) stores
        the match for plotting. On failure the marker is recorded in
        ``self.markers_to_remove`` and the placeholder location ``(1, 1)`` is
        returned. ``self.template_matched`` reflects the outcome of this call.
        """
        self.template_matched = False
        template = iio.imread(self.paths_template[marker_id][level])

        accuracy, location = self._match_template(image=image, template=template)
        self.template_matched = self._evaluate_template_matching(accuracy)

        if self.template_matched:
            if self.visualize_matching:
                self.visualization[marker_id].append(self._create_matched_template_plot(image=image, template=template, location=location, accuracy=accuracy))
            return template, location

        if marker_id not in self.markers_to_remove:
            self.markers_to_remove.append(marker_id)
            print(f'couldnt find a match for {marker_id}\n'
                  'you need to set it manually')
        return template, (1, 1)

    def _match_template(self, image: np.ndarray, template: np.ndarray) -> tuple[float, tuple[int, int]]:
        """Return (score, top-left (x, y) location) of the best match of ``template`` in ``image``.

        With OpenCV the score is the TM_CCORR_NORMED peak. With scikit-image it
        is the peak of ``match_template``'s normalized cross-correlation; note
        that the two scores are not identical measures.
        """
        # the copy yields a contiguous array even when ``image`` is a cropped view
        img = image.copy()
        if self.use_open_cv:
            method = cv2.TM_CCORR_NORMED  # alternatives: cv2.TM_SQDIFF_NORMED, cv2.TM_CCOEFF_NORMED

            result = cv2.matchTemplate(image=img, templ=template, method=method)
            min_value, max_value, min_location, max_location = cv2.minMaxLoc(result)

            if method == cv2.TM_SQDIFF_NORMED:
                # squared-difference methods mark the best match at the minimum
                location = min_location
                accuracy = 1 - abs(min_value)
            else:
                # all other methods mark the best match at the maximum
                location = max_location
                accuracy = max_value
        else:
            result = match_template(img, template)
            ij = np.unravel_index(np.argmax(result), result.shape[0:2])
            x, y = ij[::-1]
            location = (int(x), int(y))
            # previously left as None, which made _evaluate_template_matching crash
            accuracy = float(np.max(result))
        return accuracy, location

    def _read_positions_image(self) -> np.ndarray:
        """Read the first frame of the positions image and release the reader afterwards."""
        reader = iio.get_reader(self.path_positions)
        try:
            return reader.get_data(0)
        finally:
            reader.close()

    def _transpose_coordinates(self, coordinates: tuple[int, int], offset: tuple[int, int]) -> tuple[int, int]:
        """Shift region-local (x, y) ``coordinates`` by ``offset`` into full-image coordinates."""
        return (coordinates[0] + offset[0], coordinates[1] + offset[1])

    def _visualize_predictions(self) -> None:
        """Plot all matched markers on the positions image, then one figure per marker showing each level's match."""
        plt.figure(figsize=(20, 10))
        plt.imshow(self._read_positions_image())
        for marker_id in self.paths_template.keys():
            marker_coords = self.manual_test_position_marker_coords_pred[marker_id]
            plt.scatter(marker_coords['x'], marker_coords['y'])
            plt.text(marker_coords['x'][0], marker_coords['y'][0], marker_id)
        plt.show()

        for marker_id in self.paths_template.keys():
            fig = plt.figure(figsize=(10, 5))
            gs = fig.add_gridspec(1, len(self.visualization[marker_id]))
            plt.suptitle(marker_id, x=0.5, y=0.7, weight='bold', size='x-large')

            for n, (image, rectangle, accuracy) in enumerate(self.visualization[marker_id], start=1):
                ax = fig.add_subplot(gs[0, n - 1])
                plt.imshow(image)
                plt.plot(rectangle[0], rectangle[1])
                plt.title(label=f'level{n}')
                ax.xaxis.set_label_position('top')
                if accuracy is not None:
                    ax.set_xlabel(f'Accuracy: {accuracy}')

            plt.show()

In [None]:
# All cameras of the 220826 session share one positions directory.
POSITIONS_DIR_220826 = '/Users/kobel/Documents/Medizin/Doktorarbeit/Coding/positions/220826/'


def run_template_matching(cam_id: str, directory_positions_jpg: str = POSITIONS_DIR_220826, visualize_matching: bool = True, use_open_cv: bool = True) -> TemplateMatching:
    """Create a TemplateMatching for one camera, run it, and return the instance."""
    matcher = TemplateMatching(directory_positions_jpg=directory_positions_jpg, cam_id=cam_id, visualize_matching=visualize_matching, use_open_cv=use_open_cv)
    matcher.run()
    return matcher


Bottom_220826 = run_template_matching("Bottom")
Top_220826 = run_template_matching("Top")
Side1_220826 = run_template_matching("Side1")
Side2_220826 = run_template_matching("Side2")
Ground1_220826 = run_template_matching("Ground1")
Ground2_220826 = run_template_matching("Ground2")

## Konstantin's image preprocessing before template matching (not implemented yet)
## -> Open question: the templates themselves are not preprocessed, so preprocessing only the positions images would make them inconsistent with the templates. Does it make sense to include preprocessing here?

In [None]:
from PIL import Image, ImageEnhance, ImageOps

def image_preprocess(img, kernel_name):
    #contrast enhancement
    img = Image.fromarray(img)
    img_contr_obj=ImageEnhance.Contrast(img)
    img=img_contr_obj.enhance(2)
    img = np.array(img)
    
    
    #kernel application
    img = np.array(img)
    if kernel_name == "sharpening":
        kernel = np.array([[0, -1, 0],
                        [-1, 5, -1],
                        [0, -1, 0]])
    elif kernel_name == "edge_detection":
         kernel = np.array([[0, -1, 0],
                        [-1, 4, -1],
                        [0, -1, 0]])
    
    kerneled_single_color_frames = []
    for rgb_index in range(3):
        image_convolved = convolve2d(img[:,:,rgb_index], kernel, 'valid')
        kerneled_single_color_frames.append(image_convolved)

    kerneled_frame = np.asarray(kerneled_single_color_frames)
    kerneled_frame = np.moveaxis(kerneled_frame, 0, -1)
    
    img = kerneled_frame
    
    #binarization
    mean = (img.max() + img.min())/2
    img = np.where(img < mean, img, 255) 
    img = np.where(img > mean, img, 0)
    img = img.astype("uint8")