# Utils

In [None]:
import tensorflow as tf
from typing import Union, TypeVar
from collections import namedtuple
import cv2

# Pi as a float32 scalar tensor, obtained as the phase angle of the complex number -1.
PI = tf.cast(
    tf.math.angle(tf.constant(-1, dtype=tf.complex64)), tf.float32
)

# Type variable bound to KeyPoints; used to annotate arguments/returns of that type.
KT = TypeVar('KT', bound='KeyPoints')
# Result record of KeyPoints.unpack_octave: the fields are (octave, layer, scale).
unpacked_octave = namedtuple('unpacked_octave', 'octave, layer, scale')


class Octave:
    """A single octave: its index, its scale-space tensor and the gradient maps of it.

    Attributes:
        index: the octave index given by the caller.
        gss: the tensor given by the caller (stored untouched).
        magnitude, orientation: the pair returned by ``compute_mag_ori(gss)``.
        shape: read-only static shape of ``gss`` captured at construction time.
    """

    def __init__(
            self,
            index: int,
            gss: tf.Tensor
    ):
        # The static shape is read first and frozen as a plain python list.
        self.__shape = gss.get_shape().as_list()
        self.index, self.gss = index, gss
        gradient_maps = compute_mag_ori(gss)
        self.magnitude, self.orientation = gradient_maps

    @property
    def shape(self) -> list:
        """Static shape of ``gss`` as a python list."""
        return self.__shape


class KeyPoints:
    """Column-wise container of key points backed by float32 TensorFlow tensors.

    Every field holds one row per key point:

    * ``pt``          -- (N, 3); the first column is used as the batch index by
      ``partition_by_batch`` / ``n_batches``, the other two are coordinates.
    * ``scale_index`` -- (N, 1) or ``None`` when not tracked.
    * ``size``, ``angle``, ``octave``, ``response`` -- (N, 1) each.

    ``as_image_size`` flags whether ``to_image_size`` has already been applied.
    """

    def __init__(
            self,
            pt: Union[None, tf.Tensor] = None,
            size: Union[None, tf.Tensor] = None,
            angle: Union[None, tf.Tensor] = None,
            octave: Union[None, tf.Tensor] = None,
            response: Union[None, tf.Tensor] = None,
            as_image_size: bool = False
    ):
        """Create an empty container, or fill it when ``pt`` is given.

        When ``pt`` is not None all the other fields must be tensors with the same
        number of rows; ``pt`` may carry a 4th column which is stored as ``scale_index``.

        Raises:
            ValueError: if a field is not a tensor or the shapes are inconsistent.
        """
        self.scale_index = None
        self.__shape = (0,)
        self.__n_batch = None
        self.pt = tf.constant([[]], shape=(0, 3), dtype=tf.float32)
        self.size = tf.constant([[]], shape=(0, 1), dtype=tf.float32)
        self.angle = tf.constant([[]], shape=(0, 1), dtype=tf.float32)
        self.octave = tf.constant([[]], shape=(0, 1), dtype=tf.float32)
        self.response = tf.constant([[]], shape=(0, 1), dtype=tf.float32)
        self.as_image_size = as_image_size
        if pt is not None: self.__constructor(pt, size, angle, octave, response)

    def __merged_array(
            self,
            other: KT
    ) -> tf.Tensor:
        """Validate ``other`` and return both containers stacked row-wise as one array.

        If only one operand tracks ``scale_index`` the other operand gets a
        ``scale_index`` column filled with -1 (the operand itself is updated, which
        keeps the column layout of both arrays identical).

        Raises:
            TypeError: if ``other`` is not a KeyPoints instance.
            ValueError: if the two operands disagree on ``as_image_size``.
        """
        if not isinstance(other, KeyPoints): raise TypeError('can only combine KeyPoints with KeyPoints')
        if self.as_image_size ^ other.as_image_size:
            raise ValueError('the as_image_size parameter is inconsistent')
        if (self.scale_index is None) ^ (other.scale_index is None):
            missing = self if self.scale_index is None else other
            missing.scale_index = tf.ones((missing.shape[0], 1), tf.float32) * -1
        return tf.concat((self.as_array(), other.as_array()), axis=0)

    def __add__(
            self,
            other: KT
    ) -> KT:
        """Return a new KeyPoints holding the rows of ``self`` followed by ``other``."""
        return self.from_array(self.__merged_array(other), inplace=False)

    def __iadd__(
            self,
            other: KT
    ) -> KT:
        """Append the rows of ``other`` to this container and return ``self``."""
        self.from_array(self.__merged_array(other), inplace=True)
        return self

    def __constructor(
            self,
            pt,
            size,
            angle,
            octave,
            response
    ):
        """Validate the fields and store them on the instance.

        Raises:
            ValueError: if a field is not a tensor, ``pt`` is not (N, 3 or 4) after
                squeezing, or a field does not have N rows.
        """
        if not isinstance(pt, tf.Tensor): raise ValueError('All the fields need to be type of tf.Tensor')
        _shape = pt.get_shape().as_list()
        if len(_shape) > 2:
            pt = tf.squeeze(pt)
            _shape = pt.get_shape().as_list()
        if len(_shape) != 2 or _shape[-1] < 3: raise ValueError(
            'expected "pt" to be 2D tensor with size of (None, 3 or 4)')
        if _shape[-1] == 4:
            # a 4th column carries the scale index
            pt, scale_index = tf.split(pt, [3, 1], axis=-1)
        else:
            scale_index = None
        valid = [pt]
        for f in [size, angle, octave, response]:
            if not isinstance(f, tf.Tensor): raise ValueError('All the fields need to be type of tf.Tensor')
            f = tf.reshape(f, (-1, 1))
            if f.get_shape()[0] != _shape[0]: raise ValueError('All the fields need to be with the same first dim size')
            valid.append(f)
        self.pt, self.size, self.angle, self.octave, self.response = valid
        self.scale_index = scale_index
        self.__shape = (_shape[0],)
        # the cached batch count belongs to the previous content
        self.__n_batch = None

    @property
    def shape(self) -> tuple:
        """One-element tuple ``(N,)`` with the number of stored key points."""
        return self.__shape

    def as_array(self) -> tf.Tensor:
        """Concatenate all the fields column-wise.

        Column layout: ``pt`` (3), ``scale_index`` (1, only when tracked), ``size``,
        ``angle``, ``octave``, ``response`` -- i.e. 7 or 8 columns.
        """
        _array = [self.pt]
        _array += [self.scale_index] if self.scale_index is not None else []
        _array += [self.size, self.angle, self.octave, self.response]
        return tf.concat(_array, axis=-1)

    def from_array(
            self,
            array: tf.Tensor,
            inplace=False
    ) -> Union[None, KT]:
        """Rebuild key points from an array laid out as produced by ``as_array``.

        Args:
            array: 2D tensor with 7 columns, or 8 when a scale index is included.
            inplace: when True overwrite this instance and return None, otherwise
                return a new KeyPoints that inherits ``as_image_size``.

        Raises:
            ValueError: if ``array`` is not 2D with 7 or 8 columns.
        """
        _shape = array.get_shape().as_list()
        if len(_shape) != 2 or _shape[1] not in (7, 8):
            raise ValueError('array rank need to be 2 with size of (None, 7 or 8)')
        splits = [4] if _shape[1] == 8 else [3]
        splits += [1, 1, 1, 1]
        split = tf.split(array, splits, axis=-1)
        if not inplace: return KeyPoints(*split, as_image_size=self.as_image_size)
        self.__constructor(*split)

    def to_image_size(
            self,
            inplace=False
    ) -> Union[None, KT]:
        """Convert the key points to image size.

        The two coordinate columns of ``pt`` and ``size`` are multiplied by 0.5 and
        ``octave`` is xor-ed with 255 (as integer) and stored back as float32.
        Nothing is done when the container is empty or already converted.

        Args:
            inplace: when True update this instance and return None; otherwise
                return the converted KeyPoints (``self`` if no conversion is needed).
        """
        if self.shape[0] == 0 or self.as_image_size: return self if not inplace else None
        pt_unpack = self.pt * tf.constant([1.0, 0.5, 0.5], dtype=tf.float32)
        size_unpack = self.size * 0.5
        # Keep octave float32 like every other field, otherwise as_array() cannot
        # concatenate the columns after an in-place conversion.
        octave_unpack = tf.cast(tf.cast(self.octave, dtype=tf.int64) ^ 255, dtype=tf.float32)
        if inplace:
            self.pt, self.size, self.octave = pt_unpack, size_unpack, octave_unpack
            self.as_image_size = True
            return None
        if self.scale_index is not None: pt_unpack = tf.concat((pt_unpack, self.scale_index), -1)
        return KeyPoints(pt_unpack, size_unpack, self.angle, octave_unpack, self.response, True)

    def relies_scale_index(self):
        """Drop the scale index column."""
        self.scale_index = None

    def unpack_octave(self) -> unpacked_octave:
        """Decode the packed ``octave`` field into (octave, layer, scale) float32 tensors.

        Returns ``unpacked_octave(None, None, None)`` for an empty container.
        The decoding is done on the image-size representation of the key points.
        """
        if self.shape[0] == 0: return unpacked_octave(None, None, None)
        up_key_points = self.to_image_size(inplace=False) if not self.as_image_size else self

        octave_unpack = tf.cast(up_key_points.octave, tf.int64)
        # low byte -> octave
        octave = octave_unpack & 255
        octave = (octave ^ 255) - 1

        # second byte -> layer
        layer = tf.bitwise.right_shift(octave_unpack, 8)
        layer = layer & 255

        # scale is 1 / 2^octave for non-negative octaves and 2^(-octave) otherwise
        scale = tf.where(
            octave >= 0, tf.cast(1 / tf.bitwise.left_shift(1, octave), dtype=tf.float32),
            tf.cast(tf.bitwise.left_shift(1, -octave), dtype=tf.float32)
        )
        octave = octave + 1
        return unpacked_octave(tf.cast(octave, dtype=tf.float32), tf.cast(layer, dtype=tf.float32), scale)

    def partition_by_batch(
            self,
            descriptors: Union[tf.Tensor, None] = None
    ) -> Union[None, tuple[list[KT], Union[list[tf.Tensor], None]]]:
        """Split the key points (and optionally their descriptors) per batch index.

        The batch index is the first column of ``pt`` cast to int32; one partition
        is produced for every index in ``0 .. max(batch index)``.

        Returns:
            ``(key_points_per_batch, descriptors_per_batch)`` where the second item is
            None when no descriptors were given; None when the container is empty.
        """
        if self.shape[0] == 0: return None
        part = tf.reshape(tf.cast(tf.split(self.pt, [1, 2], -1)[0], tf.int32), (-1,))
        n_parts = tf.reduce_max(part) + 1
        if descriptors is not None:
            descriptors = tf.dynamic_partition(descriptors, part, n_parts)
        arrays = tf.dynamic_partition(self.as_array(), part, n_parts)
        out = [self.from_array(p, inplace=False) for p in arrays]
        return out, descriptors

    def partition_by_index(
            self,
            partition_index: tf.Tensor
    ) -> Union[None, list[KT]]:
        """Split the key points according to ``partition_index`` (one entry per key point).

        Returns None when the container is empty.

        Raises:
            ValueError: if ``partition_index`` does not have one entry per key point.
        """
        if self.shape[0] == 0: return None
        if partition_index.get_shape()[0] != self.shape[0]:
            raise ValueError('partition_index shape not equal to key points shape')
        part = tf.dynamic_partition(self.as_array(), partition_index, tf.reduce_max(partition_index) + 1)
        out = [self.from_array(p, inplace=False) for p in part]
        return out

    def n_batches(self) -> int:
        """Number of batches spanned by the key points (max - min batch index + 1).

        The value is cached until the content is replaced; an empty container has 0.
        """
        if self.__n_batch is not None: return self.__n_batch
        if self.shape[0] == 0:
            # reducing over an empty tensor yields no usable min / max
            self.__n_batch = 0
            return self.__n_batch
        batch = tf.split(self.pt, [1, 2], -1)[0]
        self.__n_batch = int(tf.reduce_max(batch)) - int(tf.reduce_min(batch)) + 1
        return self.__n_batch


def gaussian_kernel(
        kernel_size: int,
        sigma: Union[None, float] = None
) -> tf.Tensor:
    """Build a normalized (sums to 1) 2D Gaussian kernel.

    Args:
        kernel_size: odd size greater than 2, or 0 to derive the size from ``sigma``.
        sigma: standard deviation; when None it is derived from ``kernel_size`` with
            ``0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.

    Raises:
        ValueError: if ``kernel_size`` is 0 and ``sigma`` is missing or below 0.8.
        AssertionError: if the (possibly derived) kernel size is even or not above 2.
    """
    if kernel_size == 0:
        if sigma is None:
            raise ValueError('need sigma parameter when the kernel size is 0')
        if sigma < 0.8:
            raise ValueError('minimum kernel need to be size of 3 --> sigma > 0.8')
        # inverse of the sigma formula used below
        kernel_size = ((((sigma - 0.8) / 0.3) + 1) * 2) + 1
        if (kernel_size % 2) == 0:
            kernel_size = kernel_size + 1

    assert kernel_size % 2 != 0 and kernel_size > 2

    if sigma is None:
        sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8

    # symmetric integer offsets around the kernel center
    first, last = -kernel_size // 2 + 1.0, kernel_size // 2 + 1.0
    offsets = tf.range(first, last)
    grid_x, grid_y = tf.meshgrid(offsets, offsets)

    scale = 1 / (2.0 * PI * (sigma ** 2))
    exponent = -((grid_x ** 2) + (grid_y ** 2)) / (2.0 * (sigma ** 2))
    weights = tf.exp(exponent) * scale
    return weights / tf.reduce_sum(weights)


def make_neighborhood2D(
        init_cords: tf.Tensor,
        con: int = 3,
        origin_shape: Union[None, tuple, list, tf.TensorShape] = None
) -> tf.Tensor:
    """Expand every (b, y, x, d) coordinate into its ``con`` x ``con`` (y, x) window.

    Args:
        init_cords: 2D tensor with 4 columns (b, y, x, d).
        con: window side length (converted with ``int``).
        origin_shape: optional 4-item shape; when given, only windows that lie
            completely inside ``[0, origin_shape[1])`` x ``[0, origin_shape[2])``
            are returned.

    Returns:
        Tensor of shape (n, con ** 2, 4) holding the neighbor coordinates.

    Raises:
        TypeError: if ``init_cords`` is not a tensor.
    """
    if not isinstance(init_cords, tf.Tensor): raise TypeError("cords need to be of type Tensor")
    _, n_cols = init_cords.get_shape()
    con = int(con)
    assert n_cols == 4
    n_cells = con ** 2

    steps = tf.range(-con // 2 + 1, (con // 2) + 1, dtype=tf.int64)
    grids = tf.meshgrid(steps, steps)
    # meshgrid yields (x, y) grids; reversed so the last axis is ordered (y, x)
    window = tf.reshape(tf.stack(grids[::-1], axis=-1), shape=(1, n_cells, 2))

    batch, center_yx, depth = tf.split(init_cords, [1, 2, 1], axis=1)
    shifted_yx = center_yx[:, tf.newaxis, ...] + window
    # batch and depth are constant over the window, so they are just tiled
    batch = tf.repeat(batch[:, tf.newaxis, ...], repeats=n_cells, axis=1)
    depth = tf.repeat(depth[:, tf.newaxis, ...], repeats=n_cells, axis=1)

    neighbor = tf.concat((batch, shifted_yx, depth), axis=-1)
    if origin_shape is None:
        return neighbor

    assert len(origin_shape) == 4
    # the bounds test is done on 1-based coordinates, the shift is undone at the end
    one_based = neighbor + 1
    _, rows, cols, _ = tf.unstack(one_based, num=4, axis=-1)

    rows_inside = (rows >= 1) & (rows <= origin_shape[1])
    cols_inside = (cols >= 1) & (cols <= origin_shape[2])

    # keep a window only if every one of its cells is inside
    window_inside = tf.reduce_all(tf.logical_and(rows_inside, cols_inside), axis=-1)
    return tf.gather_nd(one_based, tf.where(window_inside)) - 1


def compute_extrema3D(
        X: tf.Tensor,
        threshold: Union[tf.Tensor, float, None] = None,
        con: Union[tf.Tensor, int, tuple, list] = 3,
        border_width: Union[tf.Tensor, tuple, list, None] = None,
        epsilon: Union[tf.Tensor, float] = 1e-07
) -> tf.Tensor:
    """Locate the local minima / maxima of a 4D tensor over its last three axes.

    A position is an extremum when its value equals (within ``epsilon``) the maximum
    or the minimum of the ``con`` window centered on it.

    Args:
        X: tensor of shape (B, H, W, C) with a fully known static shape.
        threshold: when given, only extrema with ``abs(value) > threshold`` are kept.
        con: window size, an int or an iterable of exactly 3 sizes for (H, W, C).
        border_width: optional 3 items (y, x, d); only extrema with
            ``border <= coordinate <= size - border`` on every axis are kept.
        epsilon: tolerance of the equality test.

    Returns:
        int64 tensor of shape (n, 4) with the (b, y, x, d) coordinates of the extrema.

    Raises:
        TypeError: if ``X`` is not a tensor.
        ValueError: if ``X`` is not 4D, ``con`` is not an int / 3 items (or is a
            tensor without a static value), or ``border_width`` is not 3 items.
    """
    if not isinstance(X, tf.Tensor): raise TypeError("X need to be of type Tensor")
    _shape = X.get_shape().as_list()
    _n_dims = len(_shape)
    if _n_dims != 4:
        raise ValueError(
            'expected the inputs to be 4D tensor with size of (None, H, W, C)'
        )
    b, h, w, d = tf.unstack(tf.cast(_shape, dtype=tf.int64), num=4, axis=-1)

    X = tf.cast(X, dtype=tf.float32)

    threshold = tf.cast(threshold, dtype=tf.float32) if threshold is not None else None

    if tf.is_tensor(con):
        static_con = tf.get_static_value(tf.reshape(con, shape=(-1,)))
        if static_con is None:
            raise ValueError('con parameter need to be int or iterable with size 3')
        # plain python ints, so the values can be used as ksize and for list math
        con = int(static_con[0]) if len(static_con) == 1 else tuple(int(c) for c in static_con)

    if isinstance(con, int):
        con = (con, con, con)

    # shorter iterables used to pass this check and fail later inside the pooling op
    if len(con) != 3:
        raise ValueError('con parameter need to be int or iterable with size 3')

    half_con = [c // 2 for c in con]

    # channel 0 holds X and channel 1 holds -X: one max-pool gives both max and min
    x_con = tf.concat((tf.expand_dims(X, -1), tf.expand_dims(X, -1) * -1.), -1)

    extrema = tf.nn.max_pool3d(x_con, ksize=con, strides=[1, 1, 1], padding='VALID')

    extrema_max, extrema_min = tf.unstack(extrema, 2, -1)
    extrema_min = extrema_min * -1.

    # the window centers, aligned with the 'VALID' pooling output
    compare_array = tf.slice(X, [0, *half_con], [b, h - 2 * half_con[0], w - 2 * half_con[1], d - 2 * half_con[2]])

    def _equal_with_epsilon(arr):
        return tf.logical_and(
            tf.math.greater_equal(arr, compare_array - epsilon),
            tf.math.less_equal(arr, compare_array + epsilon)
        )

    extrema_cond = tf.logical_or(
        _equal_with_epsilon(extrema_max),
        _equal_with_epsilon(extrema_min)
    )
    if threshold is not None:
        extrema_cond = tf.logical_and(extrema_cond, tf.math.greater(tf.abs(compare_array), threshold))

    byxd = tf.where(extrema_cond)

    # back to the coordinates of the un-cropped input
    byxd = byxd + tf.constant([[0] + half_con], dtype=tf.int64)

    if border_width is not None:
        if tf.is_tensor(border_width):
            border_width = tf.get_static_value(tf.reshape(border_width, shape=(-1,)))
            border_width = tuple(border_width)
        if len(border_width) != 3:
            raise ValueError('border_width need to be with len of 3')
        cb, cy, cx, cd = tf.unstack(byxd, num=4, axis=-1)
        by, bx, bd = tf.unstack(tf.cast(border_width, dtype=tf.int64), num=3)
        y_cond = tf.logical_and(tf.math.greater_equal(cy, by), tf.math.less_equal(cy, h - by))
        x_cond = tf.logical_and(tf.math.greater_equal(cx, bx), tf.math.less_equal(cx, w - bx))
        d_cond = tf.logical_and(tf.math.greater_equal(cd, bd), tf.math.less_equal(cd, d - bd))

        casted_ = tf.logical_and(tf.logical_and(y_cond, x_cond), d_cond)
        byxd = tf.boolean_mask(byxd, casted_)

    return byxd


def compute_central_gradient3D(
        X: tf.Tensor
) -> tf.Tensor:
    """Central-difference gradient of a 4D tensor along its last three axes.

    Args:
        X: tensor of shape (B, H, W, C).

    Returns:
        float32 tensor of shape (B, H - 2, W - 2, C - 2, 3); the last axis holds the
        differences along W, H and C (in that order), each multiplied by 0.5.

    Raises:
        TypeError: if ``X`` is not a tensor.
        ValueError: if ``X`` is not 4D.
    """
    if not isinstance(X, tf.Tensor): raise TypeError("X need to be of type Tensor")
    if len(X.get_shape().as_list()) != 4:
        raise ValueError(
            'expected the inputs to be 4D tensor with size of (None, H, W, C)'
        )

    def _embed(stencil):
        # place a 3x3 in-plane stencil in the middle slice of a 3x3x3 kernel
        plane = tf.reshape(tf.constant(stencil, dtype=tf.float32), shape=(3, 3, 1, 1, 1))
        return tf.pad(
            plane,
            paddings=tf.constant([[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]]),
            constant_values=0.0
        )

    kernel_x = _embed([[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]])
    kernel_y = _embed([[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    # the difference along the third axis only touches the two end slices
    kernel_z = tf.tensor_scatter_nd_update(
        tf.zeros_like(kernel_x),
        tf.constant([[1, 1, 0, 0, 0], [1, 1, 2, 0, 0]]),
        tf.constant([-1.0, 1.0])
    )
    stencils = tf.concat((kernel_x, kernel_y, kernel_z), axis=-1)

    volume = tf.expand_dims(tf.cast(X, dtype=tf.float32), axis=-1)
    return tf.nn.convolution(volume, stencils, padding='VALID') * 0.5


def compute_hessian_3D(
        X: tf.Tensor
) -> tf.Tensor:
    """Finite-difference Hessian of a 4D tensor over its last three axes.

    Args:
        X: tensor of shape (B, H, W, C).

    Returns:
        float32 tensor of shape (B, H - 2, W - 2, C - 2, 3, 3) holding the symmetric
        matrix [[dxx, dxy, dxz], [dxy, dyy, dyz], [dxz, dyz, dzz]] per position,
        where x, y, z are the W, H and C axes.

    Raises:
        TypeError: if ``X`` is not a tensor.
        ValueError: if ``X`` is not 4D.
    """
    if not isinstance(X, tf.Tensor): raise TypeError("X need to be of type Tensor")
    if len(X.get_shape().as_list()) != 4:
        raise ValueError(
            'expected the inputs to be 4D tensor with size of (None, H, W, C)'
        )

    def _embed(stencil):
        # place a 3x3 in-plane stencil in the middle slice of a 3x3x3 kernel
        plane = tf.reshape(tf.constant(stencil, dtype=tf.float32), shape=(3, 3, 1, 1, 1))
        return tf.pad(
            plane,
            paddings=tf.constant([[0, 0], [0, 0], [1, 1], [0, 0], [0, 0]]),
            constant_values=0.0
        )

    def _sparse(taps, weights):
        # 3x3x3 kernel that is zero except for the listed taps
        return tf.tensor_scatter_nd_update(
            tf.zeros((3, 3, 3, 1, 1), dtype=tf.float32), tf.constant(taps), tf.constant(weights)
        )

    # pure second derivatives
    k_xx = _embed([[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]])
    k_yy = _embed([[0.0, 1.0, 0.0], [0.0, -2.0, 0.0], [0.0, 1.0, 0.0]])
    k_zz = _sparse([[1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 2, 0, 0]], [1.0, -2.0, 1.0])
    pure_kernels = tf.concat((k_xx, k_yy, k_zz), axis=-1)

    # mixed second derivatives (scaled by 0.25 after the convolution)
    k_xy = _embed([[1.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0]])
    k_xz = _sparse(
        [[1, 0, 0, 0, 0], [1, 2, 2, 0, 0], [1, 0, 2, 0, 0], [1, 2, 0, 0, 0]],
        [1.0, 1.0, -1.0, -1.0]
    )
    k_yz = _sparse(
        [[0, 1, 0, 0, 0], [2, 1, 2, 0, 0], [0, 1, 2, 0, 0], [2, 1, 0, 0, 0]],
        [1.0, 1.0, -1.0, -1.0]
    )
    mixed_kernels = tf.concat((k_xy, k_yz, k_xz), axis=-1)

    volume = tf.expand_dims(tf.cast(X, dtype=tf.float32), axis=-1)
    pure = tf.nn.convolution(volume, pure_kernels, padding='VALID')
    mixed = tf.nn.convolution(volume, mixed_kernels, padding='VALID') * 0.25

    dxx, dyy, dzz = tf.unstack(pure, 3, axis=-1)
    dxy, dyz, dxz = tf.unstack(mixed, 3, axis=-1)

    rows = ((dxx, dxy, dxz), (dxy, dyy, dyz), (dxz, dyz, dzz))
    return tf.stack([tf.stack(row, axis=-1) for row in rows], axis=-1)


def compute_mag_ori(
        gss: tf.Tensor
) -> tuple[tf.Tensor, tf.Tensor]:
    """Gradient magnitude and orientation (in degrees) of every in-plane position.

    The gradient is the un-normalized central difference along the two in-plane
    axes, computed with a 'VALID' convolution (one border row/column is lost on
    every side); the last axis of ``gss`` is left untouched.

    Returns:
        ``(magnitude, orientation)`` where orientation is ``atan2(dy, dx)`` in degrees.

    Raises:
        TypeError: if ``gss`` is not a tensor.
    """
    if not isinstance(gss, tf.Tensor): raise TypeError("gss need to be of type Tensor")

    kernel_shape = (3, 3, 1, 1, 1)
    diff_x = tf.constant(
        [[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]], shape=kernel_shape, dtype=tf.float32
    )
    diff_y = tf.constant(
        [[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], shape=kernel_shape, dtype=tf.float32
    )
    stencils = tf.concat((diff_x, diff_y), axis=-1)

    gradient = tf.nn.convolution(gss[..., tf.newaxis], stencils, padding='VALID')
    dx, dy = gradient[..., 0], gradient[..., 1]

    magnitude = tf.math.sqrt(tf.square(dx) + tf.square(dy))
    orientation = tf.math.atan2(dy, dx) * (180.0 / PI)
    return magnitude, orientation


def load_image(
        name: str,
        color_mode: str = 'grayscale'
) -> tf.Tensor:
    """Load an image file as a float32 tensor with a leading batch axis of size 1.

    Args:
        name: path handed to ``tf.keras.utils.load_img``.
        color_mode: color mode handed to ``tf.keras.utils.load_img``.
    """
    image = tf.keras.utils.load_img(name, color_mode=color_mode)
    pixels = tf.keras.utils.img_to_array(image)
    tensor = tf.convert_to_tensor(pixels, dtype=tf.float32)
    return tf.expand_dims(tensor, axis=0)


def templet_matching_TF(
        scr_kp: KT,
        dst_kp: KT,
        scr_dsc: tf.Tensor,
        dst_dsc: tf.Tensor,
        ratio_threshold: float = 0.7
) -> tuple[Union[list[tf.Tensor], tf.Tensor], Union[list[tf.Tensor], tf.Tensor]]:
    """Match template descriptors to destination descriptors with a ratio test (pure TensorFlow).

    For every template descriptor the two closest destination descriptors (euclidean
    distance) are found; the match is kept when
    ``best_distance < ratio_threshold * second_best_distance``.

    Args:
        scr_kp: template key points (a single batch).
        dst_kp: destination key points (one or more batches).
        scr_dsc: template descriptors, one row per template key point.
        dst_dsc: destination descriptors, one row per destination key point.
        ratio_threshold: ratio used by the test.

    Returns:
        ``(src_pt, dst_pt)``: the image-size ``pt`` rows of the matched key points,
        row ``i`` of both tensors being one match. When ``dst_kp`` spans more than
        one batch, two lists with one tensor per batch partition are returned.

    Raises:
        TypeError: if the key points / descriptors have the wrong type.
        ValueError: if the template key points span more than one batch.
    """
    if not isinstance(scr_kp, KeyPoints) or not isinstance(dst_kp, KeyPoints):
        raise TypeError('Key points need to be of type "KeyPoints"')
    if not isinstance(scr_dsc, tf.Tensor) or not isinstance(dst_dsc, tf.Tensor):
        raise TypeError('descriptors need to be of type "Tensor"')
    # The ratio test needs two candidates; nothing can match otherwise
    # (this also covers the empty partitions produced by partition_by_batch).
    if scr_kp.shape[0] == 0 or dst_kp.shape[0] < 2:
        return tf.zeros((0, 3), tf.float32), tf.zeros((0, 3), tf.float32)
    if scr_kp.n_batches() > 1: raise ValueError("number of batches in the templet key points > 1")
    if dst_kp.n_batches() > 1:
        dst_kp_, dst_dsc_ = dst_kp.partition_by_batch(descriptors=dst_dsc)
        out_src_pt = []
        out_dst_pt = []
        for kpt, dsc in zip(dst_kp_, dst_dsc_):
            # forward the caller's threshold instead of silently using the default
            src_pt_, dst_pt_ = templet_matching_TF(scr_kp, kpt, scr_dsc, dsc, ratio_threshold)
            out_src_pt.append(src_pt_)
            out_dst_pt.append(dst_pt_)
        return out_src_pt, out_dst_pt

    # pairwise euclidean distances, arranged as (n_src, n_dst)
    diff = tf.transpose(tf.expand_dims(scr_dsc, 0), (0, 2, 1)) - tf.expand_dims(dst_dsc, -1)
    diff = tf.norm(diff, ord='euclidean', axis=1)
    diff = tf.transpose(diff, (1, 0))

    # two nearest destination descriptors of every template descriptor
    _, indices = tf.math.top_k(-diff, k=2)
    values = tf.gather(diff, indices, batch_dims=-1)

    m_dist, n_dist = tf.unstack(values, 2, -1)
    mask = m_dist < ratio_threshold * n_dist

    des_index = tf.boolean_mask(tf.unstack(indices, 2, -1)[0], mask)
    # reshape (not squeeze) keeps a rank-1 index even when exactly one match survives
    scr_index = tf.cast(tf.reshape(tf.where(mask), (-1,)), tf.int32)

    src_pt = tf.gather(scr_kp.to_image_size().pt, scr_index)
    dst_pt = tf.gather(dst_kp.to_image_size().pt, des_index)
    return src_pt, dst_pt


def templet_matching_CV2(
        scr_kp: KT,
        dst_kp: KT,
        scr_dsc: tf.Tensor,
        dst_dsc: tf.Tensor,
        ratio_threshold: float = 0.7
) -> tuple[Union[list[tf.Tensor], tf.Tensor], Union[list[tf.Tensor], tf.Tensor]]:
    """Match template descriptors to destination descriptors with OpenCV FLANN and a ratio test.

    For every template descriptor the two nearest destination descriptors are
    queried; the match is kept when
    ``best.distance < ratio_threshold * second_best.distance``. The kept matches are
    ordered by ascending distance.

    Args:
        scr_kp: template key points (a single batch).
        dst_kp: destination key points (one or more batches).
        scr_dsc: template descriptors, one row per template key point.
        dst_dsc: destination descriptors, one row per destination key point.
        ratio_threshold: ratio used by the test.

    Returns:
        ``(src_pt, dst_pt)``: the image-size ``pt`` rows of the matched key points,
        row ``i`` of both tensors being one match. When ``dst_kp`` spans more than
        one batch, two lists with one tensor per batch partition are returned.

    Raises:
        TypeError: if the key points / descriptors have the wrong type.
        ValueError: if the template key points span more than one batch.
    """
    if not isinstance(scr_kp, KeyPoints) or not isinstance(dst_kp, KeyPoints):
        raise TypeError('Key points need to be of type "KeyPoints"')
    if not isinstance(scr_dsc, tf.Tensor) or not isinstance(dst_dsc, tf.Tensor):
        raise TypeError('descriptors need to be of type "Tensor"')
    # The ratio test needs two candidates; nothing can match otherwise
    # (this also covers the empty partitions produced by partition_by_batch).
    if scr_kp.shape[0] == 0 or dst_kp.shape[0] < 2:
        return tf.zeros((0, 3), tf.float32), tf.zeros((0, 3), tf.float32)
    if scr_kp.n_batches() > 1: raise ValueError("number of batches in the templet key points > 1")
    if dst_kp.n_batches() > 1:
        dst_kp_, dst_dsc_ = dst_kp.partition_by_batch(descriptors=dst_dsc)
        out_src_pt = []
        out_dst_pt = []
        for kpt, dsc in zip(dst_kp_, dst_dsc_):
            # forward the caller's threshold instead of silently using the default
            src_pt_, dst_pt_ = templet_matching_CV2(scr_kp, kpt, scr_dsc, dsc, ratio_threshold)
            out_src_pt.append(src_pt_)
            out_dst_pt.append(dst_pt_)
        return out_src_pt, out_dst_pt

    flann = cv2.FlannBasedMatcher(dict(algorithm=0, trees=5), dict(checks=50))

    matches = flann.knnMatch(scr_dsc.numpy(), dst_dsc.numpy(), k=2)

    # knnMatch may return fewer than k neighbours for a query; such entries
    # cannot be ratio-tested and are skipped
    good = [
        pair[0] for pair in matches
        if len(pair) == 2 and pair[0].distance < ratio_threshold * pair[1].distance
    ]
    good.sort(key=lambda m: m.distance)

    src_index = [m.queryIdx for m in good]
    dst_index = [m.trainIdx for m in good]

    src_pt = tf.gather(scr_kp.to_image_size().pt, tf.constant(src_index, dtype=tf.int32))
    dst_pt = tf.gather(dst_kp.to_image_size().pt, tf.constant(dst_index, dtype=tf.int32))
    return src_pt, dst_pt


# SIFT

In [None]:
from typing import Union
import tensorflow as tf
from tensorflow.python.keras import backend

# https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf

# Make float32 the default float type of the Keras backend.
backend.set_floatx('float32')
# Short aliases for the TensorFlow sub-modules used by the SIFT class.
linalg_ops = tf.linalg
math_ops = tf.math
image_ops = tf.image


class SIFT:
    def __init__(
            self,
            sigma: float = 1.6,
            assume_blur_sigma: float = 0.5,
            n_intervals: int = 3,
            n_octaves: Union[int, None] = None,
            border_width: int = 5,
            convergence_iter: int = 5
    ):
        self.sigma = sigma
        self.assume_blur_sigma = assume_blur_sigma
        self.n_intervals = n_intervals
        self.n_octaves = n_octaves
        self.border_width = border_width
        self.convergence_N = convergence_iter
        self.octave_pyramid: list[Octave] = []
        self.templet_capture: Union[None, list[KeyPoints, tf.Tensor]] = None

    def __repr__(self):
        return f'S(sigma)={self.sigma}, AssumeBlurS={self.assume_blur_sigma}, ScalesPerOctave={self.n_intervals + 3}, NumOfOctaves={self.n_octaves}'

    def __init_graph(
            self,
            inputs: tf.Tensor
    ) -> tuple[tf.Tensor, list[tf.Tensor]]:
        """Validate the input batch, build the blur kernels and fix ``self.n_octaves``.

        The number of octaves is ``round(log2(2 * min(h, w) / k)) + 1`` where ``k`` is
        the side of the last pyramid kernel, capped by the user supplied
        ``n_octaves`` when there is one.

        Returns:
            ``(inputs as float32, list of blur kernels)``.

        Raises:
            ValueError: if ``inputs`` is not a tensor of shape (None, h, w, 1).
        """
        if not isinstance(inputs, tf.Tensor): raise ValueError('Input image need to be of type Tensor')

        static_shape = inputs.get_shape().as_list()
        if len(static_shape) != 4 or static_shape[-1] != 1:
            raise ValueError('expected the inputs to be grayscale images with size of (None, h, w, 1)')

        inputs = tf.cast(inputs, dtype=tf.float32)
        height, width = static_shape[1], static_shape[2]

        kernels = self.__pyramid_kernels()

        last_kernel_side = int(kernels[-1].get_shape()[0])
        shortest_doubled_side = tf.cast(min([height * 2, width * 2]), dtype=tf.float32)
        log_ratio = math_ops.log(shortest_doubled_side)
        if last_kernel_side > 1:
            log_ratio = log_ratio - math_ops.log(tf.cast(last_kernel_side, dtype=tf.float32))
        octave_limit = int(tf.round(log_ratio / math_ops.log(2.0)) + 1)

        if self.n_octaves is not None:
            octave_limit = min(octave_limit, self.n_octaves)
        self.n_octaves = octave_limit
        return inputs, kernels

    def __pyramid_kernels(
            self
    ) -> list[tf.Tensor]:
        """Build the ``n_intervals + 3`` Gaussian kernels used to blur one octave.

        The first kernel uses ``sqrt(max(sigma^2 - (2 * assume_blur_sigma)^2, 0.64))``;
        kernel ``i`` uses the incremental blur between two consecutive scales,
        ``sqrt((K * s_prev)^2 - s_prev^2)`` with ``K = 2^(1 / n_intervals)`` and
        ``s_prev = sigma * K^(i - 1)``. Every kernel gets two trailing axes of size 1.
        """

        def _as_filter(sigma_value):
            kernel_2d = gaussian_kernel(kernel_size=0, sigma=sigma_value)
            return tf.expand_dims(tf.expand_dims(kernel_2d, axis=-1), axis=-1)

        initial_gap = (self.sigma ** 2) - ((2 * self.assume_blur_sigma) ** 2)
        initial_sigma = math_ops.sqrt(tf.maximum(initial_gap, 0.64))
        kernels = [_as_filter(initial_sigma)]

        step = tf.cast(2 ** (1 / self.n_intervals), dtype=tf.float32)
        for level in range(1, self.n_intervals + 3):
            previous = self.sigma * (step ** (level - 1))
            increment = math_ops.sqrt((step * previous) ** 2 - previous ** 2)
            kernels.append(_as_filter(increment))
        return kernels

    def __assign_descriptors(
            self,
            descriptors: tf.Tensor,
            bins: tf.Tensor,
            magnitude: tf.Tensor
    ) -> tf.Tensor:
        """Accumulate weighted magnitudes into the descriptor histograms.

        Every sample is distributed over the 8 histogram cells surrounding its
        fractional (y, x, orientation) bin, each share being the product of the
        per-axis linear weights; the orientation axis wraps around ``N_bins``.

        Args:
            descriptors: histogram tensor indexed by (key point, y + 1, x + 1, bin).
            bins: (n, 4) tensor of (key point index, y bin, x bin, orientation bin).
            magnitude: (n,) weights of the samples.

        Returns:
            The updated descriptors tensor.
        """
        N_bins, window_width = 8, 4
        _, rows, cols, _ = tf.unstack(bins, 4, -1)
        # samples whose spatial bin is outside (-1, window_width) contribute nothing
        inside = (rows > -1) & (rows < window_width) & (cols > -1) & (cols < window_width)
        magnitude = tf.boolean_mask(magnitude, inside)
        b, y, x, z = tf.unstack(tf.boolean_mask(bins, inside), 4, -1)

        # wrap the orientation bin into [0, N_bins)
        while tf.reduce_min(z) < 0:
            z = tf.where(z < 0, z + N_bins, z)
        while tf.reduce_max(z) >= N_bins:
            z = tf.where(z >= N_bins, z - N_bins, z)

        floors = [tf.round(tf.floor(c)) for c in (y, x, z)]
        frac_y, frac_x, frac_z = [tf.reshape(c - f, (-1,)) for c, f in zip((y, x, z), floors)]
        b, y, x, z = [tf.cast(c, tf.int32) for c in [b] + floors]

        # the 8 corners are visited in the order y, x, z with the z offset changing fastest
        for dy in (0, 1):
            share_y = magnitude * (frac_y if dy else 1 - frac_y)
            for dx in (0, 1):
                share_yx = share_y * (frac_x if dx else 1 - frac_x)
                for dz in (0, 1):
                    share = share_yx * (frac_z if dz else 1 - frac_z)
                    z_bin = (z + 1) % N_bins if dz else z
                    # +1 on y / x because the histogram carries a one-cell border
                    cell = tf.stack((b, y + 1 + dy, x + 1 + dx, z_bin), -1)
                    descriptors = tf.tensor_scatter_nd_add(descriptors, cell, share)

        return descriptors

    def __descriptors_per_octave(
            self,
            octave: Octave,
            key_points: KeyPoints
    ) -> tf.Tensor:
        """Compute one descriptor row per key point from the gradient maps of ``octave``.

        Args:
            octave: supplies ``shape``, ``magnitude`` and ``orientation``.
            key_points: key points to describe; presumably they all belong to
                ``octave`` -- verify against the caller.

        Returns:
            Tensor of shape (n_key_points, 4 * 4 * 8) with values in [0, 255].
        """
        scale_multiplier, window_width, N_bins, descriptor_max_value = 3, 4, 8, 0.2
        bins_per_degree = N_bins / 360.
        weight_multiplier = -1.0 / (0.5 * window_width * window_width)
        # histogram buffer with a one-cell border around the window in y and x
        descriptors = tf.zeros((key_points.shape[0], window_width + 2, window_width + 2, N_bins), tf.float32)

        key_points = key_points.to_image_size()
        unpack_oct = key_points.unpack_octave()

        # per key point multipliers (1, scale, scale): the first pt column is left as is
        scale_ = tf.pad(tf.repeat(unpack_oct.scale, 2, axis=1), tf.constant([[0, 0], [1, 0]]), constant_values=1.0)
        # rounded rows of (first pt column, y, x, layer)
        points = tf.round(tf.concat((key_points.pt * scale_, unpack_oct.layer), -1))
        histogram_width = scale_multiplier * 0.5 * unpack_oct.scale * key_points.size
        radius = tf.round(histogram_width * math_ops.sqrt(2.0) * (window_width + 1.0) * 0.5)

        _, y, x, _ = tf.split(points, [1] * 4, -1)
        # clip the radius by the distance of the point to the borders
        radius = math_ops.minimum(
            math_ops.minimum(octave.shape[1] - 3 - y, octave.shape[2] - 3 - x),
            math_ops.minimum(math_ops.minimum(y, x), radius)
        )
        radius = tf.reshape(radius, (-1,))
        # key points sharing a radius are processed together (same neighborhood size)
        parallel = tf.unique(radius)

        # original row index of every key point, grouped by radius
        indexes = tf.dynamic_partition(tf.reshape(tf.range(key_points.shape[0], dtype=tf.int32), (-1, 1)),
                                       parallel.idx, tf.reduce_max(parallel.idx) + 1)

        # per key point row: points (4 columns), angle, histogram width -- grouped like indexes
        wrap = tf.concat((points, key_points.angle, histogram_width), -1)
        wrap = tf.dynamic_partition(wrap, parallel.idx, parallel.y.get_shape()[0])

        M = octave.magnitude
        # orientation wrapped into [0, 360)
        T = octave.orientation % 360.0

        for index, wrap_i, r in zip(indexes, wrap, parallel.y):
            points, angle, width = tf.split(wrap_i, [4, 1, 1], -1)
            angle = 360.0 - angle
            n = points.get_shape()[0]
            cos = math_ops.cos((PI / 180) * angle)
            sin = math_ops.sin((PI / 180) * angle)

            # (2r + 1) x (2r + 1) window of offsets around the origin ...
            neighbor = make_neighborhood2D(tf.constant([[0, 0, 0, 0]], dtype=tf.int64), con=(r * 2) + 1)
            # ... translated to every key point of the group
            block = tf.expand_dims(tf.cast(points, tf.int64), axis=1) + neighbor

            # (y, x) offsets only, repeated for each key point of the group
            neighbor = tf.cast(tf.repeat(tf.split(neighbor, [1, 2, 1], -1)[1], n, 0), tf.float32)
            y, x = tf.unstack(neighbor, 2, -1)
            # descriptor row of the key point every sample belongs to
            b = tf.cast(tf.ones(y.get_shape(), dtype=tf.int32) * index, tf.float32)

            # offsets rotated by the key point angle, in units of the histogram width
            rotate = [(- (x * sin) + (y * cos)) / width, ((x * cos) + (y * sin)) / width]
            # Gaussian weight that decays with the rotated distance from the key point
            weight = tf.reshape(math_ops.exp(weight_multiplier * (rotate[0] ** 2 + rotate[1] ** 2)), (-1,))

            # NOTE(review): block holds coordinates built from `points`, while M / T come
            # from compute_mag_ori, which uses a 'VALID' convolution -- confirm the
            # coordinates are aligned with the gradient maps.
            magnitude = tf.gather_nd(M, tf.reshape(block, (-1, 4))) * weight
            orientation = tf.reshape(tf.gather_nd(T, tf.reshape(block, (-1, 4))), (n, -1))
            # orientation relative to the key point angle, expressed in histogram bins
            orientation = ((orientation - angle) * bins_per_degree)

            # fractional (row, y bin, x bin, orientation bin) of every sample
            hist_bin = [b] + [rot + 0.5 * window_width - 0.5 for rot in rotate] + [orientation]
            hist_bin = tf.reshape(tf.stack(hist_bin, -1), (-1, 4))

            descriptors = self.__assign_descriptors(descriptors, hist_bin, magnitude)

        # drop the one-cell border and flatten to one row per key point
        descriptors = tf.slice(descriptors, [0, 1, 1, 0], [key_points.shape[0], window_width, window_width, N_bins])
        descriptors = tf.reshape(descriptors, (key_points.shape[0], -1))

        # clip every entry at descriptor_max_value * (L2 norm of the row), then re-normalize
        threshold = tf.norm(descriptors, ord=2, axis=1, keepdims=True) * descriptor_max_value
        threshold = tf.repeat(threshold, N_bins * window_width * window_width, 1)
        descriptors = tf.where(descriptors > threshold, threshold, descriptors)
        descriptors = descriptors / tf.maximum(tf.norm(descriptors, ord=2, axis=1, keepdims=True), 1e-7)
        # scale by 512, round and saturate to [0, 255]
        descriptors = tf.round(descriptors * 512)
        descriptors = tf.maximum(descriptors, 0)
        descriptors = tf.minimum(descriptors, 255)
        return descriptors

    def localize_extrema(
            self,
            octave: Octave
    ) -> KeyPoints:
        """
        Detect extrema in the difference-of-Gaussians (DoG) volume of one octave and
        turn the ones that pass the localization tests into key points.

        Steps performed by the code below:
          1. build the DoG volume by subtracting adjacent layers of ``octave.gss``;
          2. find candidate extrema with ``compute_extrema3D``;
          3. solve, for every interior voxel, the least-squares offset ``X'`` from the
             DoG gradient and Hessian and evaluate the three acceptance tests
             (offset <= 0.5, contrast >= 0.03, principal-curvature ratio with r = 10);
          4. for candidates that fail, move them by their offset and retry, at most
             ``self.convergence_N`` times;
          5. pack position, size, octave code and response into a ``KeyPoints`` object
             (the angle is filled with -1).

        Reads ``self.border_width``, ``self.convergence_N`` and ``self.sigma``.

        :param octave: the octave whose Gaussian scale space (``octave.gss``) is analysed
        :return: the accepted key points, or an empty ``KeyPoints()`` when none survive
        :raises ValueError: if ``octave`` is not an ``Octave`` instance
        """
        if not isinstance(octave, Octave): raise ValueError('octave need to by of type "Octave"')
        # number of layers along the last axis of the scale space
        dim = octave.shape[-1]
        # neighbourhood size, max sub-pixel offset, min contrast, max principal-curvature ratio
        con, extrema_offset, contrast_threshold, eigen_ration = 3, 0.5, 0.03, 10
        octave_index = octave.index

        """Extract all the extrema point in the octave scale space"""
        # D(batch, y, x, s): layers [1:] minus layers [:-1], i.e. differences of adjacent layers
        dog = math_ops.subtract(tf.split(octave.gss, [1, dim - 1], -1)[1],
                                tf.split(octave.gss, [dim - 1, 1], -1)[0])
        dog_shape = dog.get_shape().as_list()

        # e = (batch, y, x, s) (local extrema)
        # NOTE(review): compute_extrema3D is defined elsewhere; it is used below as a list of
        # 4-column indices into `dog` (see the scatter_nd call) - confirm against its definition.
        border_width = self.border_width - 2
        extrema = compute_extrema3D(tf.round(dog), con=con, border_width=[border_width, border_width, 0])

        # the extrema search runs on the rounded, unscaled values; everything below uses dog / 255
        dog = dog / 255.0

        """Compute the key points conditions for all the image"""
        # DD / Dx
        grad = compute_central_gradient3D(dog)
        # trailing axis of size 1 so the gradient can be used as a right-hand side / column vector
        grad = tf.expand_dims(grad, -1)

        # D^2D / Dx^2
        hess = compute_hessian_3D(dog)

        # X' = - (D^2D / Dx^2)^-1 * (DD / Dx), obtained as a least-squares solve
        extrema_update = - linalg_ops.lstsq(hess, grad, l2_regularizer=0.0, fast=False)
        extrema_update = tf.squeeze(extrema_update, axis=-1)

        # (DD / Dx) * X'
        dot_ = linalg_ops.matmul(tf.expand_dims(extrema_update, 4), grad)
        dot_ = tf.squeeze(tf.squeeze(dot_, -1), -1)

        # DoG values with a one-voxel border removed on y, x and s
        # NOTE(review): presumably this matches the support of grad / hess - confirm in the helpers.
        mid_cube_values = tf.slice(dog, [0, 1, 1, 1],
                                   [dog_shape[0], dog_shape[1] - 2, dog_shape[2] - 2, dog_shape[3] - 2])

        # D(X') = D + 0.5 * (DD / Dx) * X'
        update_response = mid_cube_values + 0.5 * dot_

        hess_shape = hess.get_shape().as_list()
        # top-left 2x2 block of every Hessian: H[[Dxx, Dxy], [Dyx, Dyy]]
        hess_xy = tf.slice(hess, [0, 0, 0, 0, 0, 0], [*hess_shape[:-2], 2, 2])
        # Dxx + Dyy
        hess_xy_trace = linalg_ops.trace(hess_xy)
        # Dxx * Dyy - Dxy * Dyx
        hess_xy_det = linalg_ops.det(hess_xy)

        # |X'| <= 0.5
        # (X' is larger than 0.5 in any dimension, means that the extreme lies closer to a different sample point)
        kp_cond1 = math_ops.less_equal(math_ops.reduce_max(math_ops.abs(extrema_update), axis=-1), extrema_offset)

        # |D(X')| >= 0.03 (threshold on minimum contrast)
        kp_cond2 = math_ops.greater_equal(math_ops.abs(update_response), contrast_threshold)

        # (Dxx + Dyy) ^ 2 / (Dxx * Dyy - Dxy * Dyx) < (r + 1) ^ 2 / r
        # ---> ((Dxx + Dyy) ^ 2) * r < (Dxx * Dyy - Dxy * Dyx) * ((r + 1) ^ 2)
        # (threshold on ratio of principal curvatures; a zero determinant is rejected explicitly)
        kp_cond3 = math_ops.logical_and(
            eigen_ration * (hess_xy_trace ** 2) < ((eigen_ration + 1) ** 2) * hess_xy_det, hess_xy_det != 0
        )
        # all three tests combined, evaluated at every interior voxel
        cond = tf.where(kp_cond1 & kp_cond2 & kp_cond3, True, False)

        # boolean volume marking the candidate extrema, cropped to the same interior region
        kp_cond4 = tf.scatter_nd(extrema, tf.ones((extrema.shape[0],), dtype=tf.bool), dog_shape)
        kp_cond4 = tf.slice(kp_cond4, [0, 1, 1, 1],
                            [dog_shape[0], dog_shape[1] - 2, dog_shape[2] - 2, dog_shape[3] - 2])

        """Localize the extrema points"""
        # candidates that already satisfy every test
        sure_key_points = math_ops.logical_and(cond, kp_cond4)
        # candidates that still need to be moved and re-tested
        attempts = math_ops.logical_and(kp_cond4, ~sure_key_points)

        shape_ = sure_key_points.get_shape().as_list()

        for _ in range(self.convergence_N):
            attempts_cords = tf.where(attempts)
            if attempts_cords.shape[0] == 0: break
            # if there is only one point the shape may be (4, ), so force (N, 4)
            attempts_cords = tf.reshape(attempts_cords, (-1, 4))
            attempts_update = tf.gather_nd(extrema_update, attempts_cords)

            # the offset is stored as (x', y', s') while the coordinates are (batch, y, x, s)
            ex, ey, ez = tf.unstack(attempts_update, num=3, axis=-1)
            cd, cy, cx, cz = tf.unstack(tf.cast(attempts_cords, tf.float32), num=4, axis=1)
            attempts_next = [cd, cy + ey, cx + ex, cz + ez]

            # check that the new cords will lie within the image shape
            # NOTE(review): y is tested with ">= 0" while x and s use "> 0" - confirm this is intended.
            cond_next = tf.where(
                (attempts_next[1] >= 0) & (attempts_next[1] < shape_[1]) & (attempts_next[2] > 0) & (
                        attempts_next[2] < shape_[2]) & (attempts_next[3] > 0) & (
                        attempts_next[3] < shape_[3]))

            attempts_next = tf.stack(attempts_next, -1)
            # the cast to int64 turns the shifted float coordinates back into voxel indices
            # NOTE(review): casting truncates rather than rounds - confirm that is the desired behavior.
            attempts_next = tf.cast(tf.gather(attempts_next, tf.squeeze(cond_next)), dtype=tf.int64)
            if attempts_next.shape[0] == 0: break
            attempts_next = tf.reshape(attempts_next, (-1, 4))

            attempts_mask = tf.scatter_nd(attempts_next, tf.ones((attempts_next.shape[0],), dtype=tf.bool), shape_)

            # add new key points: moved candidates that now pass all tests and are not accepted yet
            new_cords = tf.where(attempts_mask & ~sure_key_points & cond)
            sure_key_points = tf.tensor_scatter_nd_update(sure_key_points, new_cords,
                                                          tf.ones((new_cords.shape[0],), dtype=tf.bool))
            # next points: moved candidates that are still not accepted
            attempts = math_ops.logical_and(attempts_mask, ~sure_key_points)

        """Construct the key points"""
        cords = tf.where(sure_key_points)
        if cords.shape[0] == 0: return KeyPoints()
        # shift y, x, s by one to undo the interior crop and index the full DoG volume
        kp_cords = cords + tf.constant([[0, 1, 1, 1]], dtype=tf.int64)

        # X' = - (D^2D / Dx^2)^-1 * (DD / Dx), restricted to the accepted points
        extrema_update = tf.gather_nd(extrema_update, cords)
        octave_index = tf.cast(octave_index, dtype=tf.float32)

        # x', y', s'
        ex, ey, ez = tf.unstack(extrema_update, num=3, axis=1)

        # batch, y, x, s
        cd, cy, cx, cz = tf.unstack(tf.cast(kp_cords, tf.float32), num=4, axis=1)

        # pt = (batch, (y + y') * (1 << octave), (x + x') * (1 << octave), s)
        kp_pt = tf.stack(
            (cd, (cy + ey) * (2 ** octave_index), (cx + ex) * (2 ** octave_index), cz), axis=-1
        )
        # octave = octave_index + s * (1 << 8) + round((s' + 0.5) * 255) * (1 << 16)
        kp_octave = octave_index + cz * (2 ** 8) + tf.round((ez + 0.5) * 255.0) * (2 ** 16)

        # size = sigma * 2 ^ ((s + s') / (dim - 3)) * 2 ^ (octave_index + 1)
        kp_size = self.sigma * (2 ** ((cz + ez) / (dim - 3))) * (2 ** (octave_index + 1.0))

        # |D(X')| = |D + 0.5 * (DD / Dx) * X'|
        kp_response = math_ops.abs(tf.gather_nd(update_response, cords))

        # the angle is not known yet, so it is filled with -1
        key_points = KeyPoints(
            pt=tf.reshape(kp_pt, (-1, 4)),
            size=tf.reshape(kp_size, (-1, 1)),
            angle=tf.reshape(tf.ones_like(kp_size) * -1.0, (-1, 1)),
            octave=tf.reshape(kp_octave, (-1, 1)),
            response=tf.reshape(kp_response, (-1, 1))
        )
        return key_points

    def orientation_assignment(
            self,
            octave: Octave,
            key_points: KeyPoints
    ) -> KeyPoints:
        """
        Assign orientations to the key points of one octave.

        For every key point a 36-bin histogram of gradient orientations is accumulated
        over a square neighbourhood (weighted by gradient magnitude and a Gaussian of the
        distance to the centre), smoothed circularly, and searched for local peaks that
        exceed 80% of the histogram maximum. One output key point is produced per peak,
        so a key point with several peaks is duplicated and one without any is dropped.

        ``key_points`` is overwritten in place (``from_array(..., inplace=True)``) and the
        same object is returned.

        :param octave: the octave the key points were detected in (provides magnitude / orientation maps)
        :param key_points: key points of that octave
        :return: ``key_points`` with the angle column filled in
        :raises ValueError: if an argument has the wrong type
        """
        if not isinstance(octave, Octave): raise ValueError('octave need to by of type "Octave"')
        if not isinstance(key_points, KeyPoints): raise ValueError('key_points need to by of type "KeyPoints"')

        orientation_N_bins, scale_factor, radius_factor = 36, 1.5, 3
        # one histogram row per key point
        histogram = tf.zeros((key_points.shape[0], orientation_N_bins), dtype=tf.float32)

        # scale = 1.5 * size / 2 ^ (octave index + 1)
        scale = scale_factor * key_points.size / (2 ** (octave.index + 1))

        # r[N_points, ] = 3 * scale
        radius = tf.cast(tf.round(radius_factor * scale), dtype=tf.int64)

        # wf[N_points, ]: exponent factor of the Gaussian weight, exp(wf * d^2)
        weight_factor = -0.5 / (scale ** 2)

        # points back to octave resolution: scale y and x by 1 / 2^octave, leave the batch index as is
        _prob = 1.0 / (1 << octave.index)
        _prob = tf.stack((tf.ones_like(_prob), _prob, _prob), axis=-1)
        _prob = tf.squeeze(_prob)

        # [batch, y + y', x + x', s] * N_points (the scale index is appended as the 4th column)
        region_center = tf.cast(key_points.pt * _prob, dtype=tf.int64)
        region_center = tf.concat((region_center, tf.cast(key_points.scale_index, dtype=tf.int64)), -1)

        # clamp the radius so the neighbourhood stays inside the image
        _, y, x, _ = tf.split(region_center, [1] * 4, -1)
        radius = math_ops.minimum(
            math_ops.minimum(octave.shape[1] - 3 - y, octave.shape[2] - 3 - x),
            math_ops.minimum(math_ops.minimum(y, x), radius)
        )
        radius = tf.reshape(radius, (-1,))

        # parallel computation: key points sharing the same radius are processed together
        parallel = tf.unique(radius)
        split_region = tf.dynamic_partition(
            tf.concat((tf.cast(region_center, tf.float32), weight_factor), -1), parallel.idx,
            tf.reduce_max(parallel.idx) + 1
        )
        # original row index of every key point, partitioned the same way
        index = tf.dynamic_partition(tf.reshape(tf.range(key_points.shape[0], dtype=tf.int64), (-1, 1)),
                                     parallel.idx, tf.reduce_max(parallel.idx) + 1)

        M = octave.magnitude
        T = octave.orientation

        for region_weight, r, hist_index in zip(split_region, parallel.y, index):
            region, weight = tf.split(region_weight, [4, 1], -1)
            # key points whose clamped radius is below 1 contribute nothing to the histogram
            if r < 1: continue

            # offsets of a (2r + 1)-wide neighbourhood, added to every region centre
            # NOTE(review): make_neighborhood2D is defined elsewhere - confirm its output layout.
            neighbor = make_neighborhood2D(tf.constant([[0, 0, 0, 0]], dtype=tf.int64), con=(r * 2) + 1)
            block = tf.expand_dims(tf.cast(region, tf.int64), axis=1) + neighbor

            magnitude = tf.gather_nd(M, tf.reshape(block, (-1, 4)))
            orientation = tf.gather_nd(T, tf.reshape(block, (-1, 4)))

            # Gaussian weight from the squared distance of each offset to the centre
            _, curr_y, curr_x, _ = tf.unstack(tf.cast(neighbor, dtype=tf.float32), 4, axis=-1)
            weight = tf.reshape(math_ops.exp(weight * (curr_y ** 2 + curr_x ** 2)), (-1,))

            # histogram bin of every sample
            # NOTE(review): assumes octave.orientation is expressed in degrees - confirm in compute_mag_ori.
            hist_deg = tf.cast(tf.round(orientation * orientation_N_bins / 360.), dtype=tf.int64) % orientation_N_bins

            # (key point row, bin) pairs used to accumulate the weighted magnitudes
            hist_index = tf.ones(block.get_shape()[:-1], dtype=tf.int64) * tf.reshape(hist_index, (-1, 1))
            hist_index = tf.stack((tf.reshape(hist_index, (-1,)), hist_deg), -1)
            histogram = tf.tensor_scatter_nd_add(histogram, hist_index, weight * magnitude)

        """ find peaks in the histogram """
        # histogram smooth with the kernel [1, 4, 6, 4, 1] / 16
        gaussian1D = tf.constant([1, 4, 6, 4, 1], dtype=tf.float32) / 16.0
        gaussian1D = tf.reshape(gaussian1D, shape=(-1, 1, 1))

        # circular padding: the last two bins go in front, the first two bins go at the end
        pad_ = tf.split(tf.expand_dims(histogram, axis=-1), [2, orientation_N_bins - 4, 2], 1)
        pad_ = tf.concat([pad_[-1], *pad_, pad_[0]], 1)

        smooth_histogram = tf.nn.convolution(pad_, gaussian1D, padding='VALID')
        smooth_histogram = tf.squeeze(smooth_histogram, axis=-1)

        orientation_max = tf.reduce_max(smooth_histogram, axis=-1)

        # a bin is a local peak when it equals the maximum of its 3-bin window
        # NOTE(review): "SAME" padding is not circular, unlike the smoothing above - confirm for bins 0 and 35.
        peak = tf.nn.max_pool1d(tf.expand_dims(smooth_histogram, -1), ksize=3, padding="SAME", strides=1)
        peak = tf.squeeze(peak, -1)

        # 80% of each histogram's maximum
        # NOTE(review): 36 here (and 34 / 36 below) duplicates orientation_N_bins.
        value_cond = tf.repeat(tf.reshape(orientation_max, shape=(-1, 1)), repeats=36, axis=-1) * 0.8

        # (key point row, bin) of every accepted peak
        peak = tf.where((peak == smooth_histogram) & (smooth_histogram > value_cond))

        p_idx, p_deg = tf.unstack(peak, num=2, axis=-1)

        # interpolate the peak position - parabola
        # channel 0: first-difference kernel [1, 0, -1]; channel 1: second-difference kernel [1, -2, 1]
        kernel = tf.constant([1., 0, -1.], shape=(3, 1, 1))
        kernel = tf.concat((kernel, tf.constant([1., -2., 1.], shape=(3, 1, 1))), -1)

        # circular padding by one bin on each side
        pad_ = tf.split(smooth_histogram, [1, 34, 1], -1)
        pad_ = tf.concat([pad_[-1], *pad_, pad_[0]], -1)

        interp = tf.unstack(tf.nn.convolution(tf.expand_dims(pad_, -1), kernel, padding="VALID"), 2, -1)
        # sub-bin offset, evaluated as (0.5 * first / second) % 36
        interp = 0.5 * (interp[0] / interp[1]) % 36
        interp = tf.cast(p_deg, tf.float32) + tf.gather_nd(interp, peak)

        # bin position -> angle
        orientation = 360. - interp * 360. / 36

        # an angle of (numerically) 360 is stored as 0
        orientation = tf.where(math_ops.abs(orientation - 360.) < 1e-7, 0.0, orientation)

        # rebuild the key points: one row per peak, with the angle column replaced
        wrap = key_points.as_array()
        wrap = tf.gather(wrap, p_idx)
        pt, size, _, oc, response = tf.split(wrap, [4, 1, 1, 1, 1], axis=-1)
        key_points.from_array(tf.concat((pt, size, tf.reshape(orientation, (-1, 1)), oc, response), axis=-1),
                              inplace=True)
        key_points.relies_scale_index()
        return key_points

    def write_descriptors(
            self,
            key_points: KeyPoints
    ) -> tf.Tensor:
        """
        Compute the descriptors of all key points, octave by octave.

        The key points are grouped by the octave they belong to, the descriptors of each
        group are computed against the matching entry of ``self.octave_pyramid``, and the
        results are stitched back together in the original key point order.

        :param key_points: key points whose descriptors are requested
        :return: one descriptor row per key point, in the order of ``key_points``
        :raises ValueError: if ``key_points`` is not a ``KeyPoints`` instance
        """
        if not isinstance(key_points, KeyPoints): raise ValueError('key_points need to by of type "KeyPoints"')
        unpack_oct = key_points.unpack_octave()

        # tf.unique needs a 1-D tensor. Flatten with reshape instead of squeeze: squeeze would
        # collapse a single key point (shape (1, 1)) to a scalar and make tf.unique fail.
        parallel = tf.unique(tf.reshape(unpack_oct.octave, (-1,)))

        # every key point lies in the same octave: no partitioning / stitching needed
        if parallel.y.get_shape()[0] == 1:
            return self.__descriptors_per_octave(self.octave_pyramid[int(parallel.y[0])], key_points)

        # original row index of every key point, grouped by octave
        indexes = tf.dynamic_partition(tf.reshape(tf.range(key_points.shape[0], dtype=tf.int32), (-1, 1)),
                                       parallel.idx, tf.reduce_max(parallel.idx) + 1)

        split_by_oc = key_points.partition_by_index(parallel.idx)

        condition_indices = []
        partitioned_data = []

        for keys, index, oc_id in zip(split_by_oc, indexes, parallel.y):
            oc_desc = self.__descriptors_per_octave(self.octave_pyramid[int(oc_id)], keys)
            condition_indices.append(tf.squeeze(index, -1))
            partitioned_data.append(oc_desc)

        # put every descriptor back at the row of the key point it belongs to
        descriptors = tf.dynamic_stitch(condition_indices, partitioned_data)
        return descriptors

    def build_pyramid(
            self,
            I: tf.Tensor
    ):
        """
        Build the Gaussian scale-space pyramid of ``I`` and store it in ``self.octave_pyramid``.

        The image is up-sampled by two, blurred with the first kernel, and then blurred
        incrementally with the remaining kernels inside every octave. The third layer from
        the end of an octave, resized with nearest-neighbour sampling, seeds the next octave.

        :param I: input image batch, passed through ``self.__init_graph`` first
        """
        def blur(image: tf.Tensor, kernel: tf.Tensor) -> tf.Tensor:
            # symmetric padding by half the kernel size keeps the spatial size under a VALID convolution
            half = kernel.get_shape()[0] // 2
            paddings = tf.constant([[0, 0], [half, half], [half, half], [0, 0]], tf.int32)
            return tf.nn.convolution(tf.pad(image, paddings, 'SYMMETRIC'), kernel, padding='VALID')

        I, kernels = self.__init_graph(I)
        self.octave_pyramid = []
        _, height, width, _ = I.get_shape()

        # base image of the first octave: twice the input resolution, pre-blurred
        base = blur(image_ops.resize(I, size=[height * 2, width * 2], method='bilinear'), kernels[0])

        # spatial size of the next octave's base image
        next_size = [height, width]
        last_octave = self.n_octaves - 1

        for oc_id in range(self.n_octaves):
            layers = [base]
            for kernel in kernels[1:]:
                layers.append(blur(layers[-1], kernel))

            if oc_id < last_octave:
                base = image_ops.resize(layers[-3], size=next_size, method='nearest')
                next_size = [next_size[0] // 2, next_size[1] // 2]

            # all layers of the octave stacked along the last axis
            self.octave_pyramid.append(Octave(oc_id, tf.concat(layers, -1)))

    def keypoints_with_descriptors(
            self,
            inputs: tf.Tensor,
            keep_as_templet: bool = False
    ) -> tuple[KeyPoints, tf.Tensor]:
        """
        Run the full pipeline on ``inputs``: pyramid, extrema localization, orientation
        assignment and descriptor extraction.

        After the call ``self.octave_pyramid`` is emptied and ``self.n_octaves`` is reset
        to ``None``.

        :param inputs: input image batch
        :param keep_as_templet: when True, store ``[key_points, descriptors]`` in
            ``self.templet_capture``; a previously stored template is replaced and a
            warning is issued
        :return: the key points and their descriptors
        """
        # local import: the surrounding file's import block is not part of this method
        import warnings

        # Issue a warning instead of raising the Warning class: raising aborted the call,
        # so the previous template could never be replaced as the message announces.
        if keep_as_templet and self.templet_capture:
            warnings.warn('prev templet will be removed')

        self.build_pyramid(inputs)
        key_points = KeyPoints()
        key_points.relies_scale_index()

        for oc in self.octave_pyramid:
            oc_kp = self.localize_extrema(oc)
            # nothing detected in this octave
            if oc_kp.shape[0] == 0: continue
            oc_kp = self.orientation_assignment(oc, oc_kp)
            key_points += oc_kp

        descriptors = self.write_descriptors(key_points)

        # drop the per-image state so the next call starts clean
        self.octave_pyramid = []
        self.n_octaves = None
        if keep_as_templet:
            self.templet_capture = [key_points, descriptors]
        return key_points, descriptors

    def relies_templet(self):
        """Drop the stored template by resetting ``self.templet_capture`` to ``None``."""
        self.templet_capture = None

# Plot

In [None]:
from matplotlib import pyplot as plt
import matplotlib
import tensorflow as tf
import numpy as np
from typing import Union
import cv2

# matplotlib.use("Qt5Agg")


def show(
        image: Union[np.ndarray, tf.Tensor],
        ax: plt.Axes
):
    """
    Draw a single image on the given axes.

    Tensors are converted to NumPy first. Images whose maximum exceeds 1 are cast to
    ``uint8``; 2-D or single-channel images are rendered with the gray colormap.

    :param image: image as a NumPy array or a tensor
    :param ax: matplotlib axes to draw on
    """
    pixels = image.numpy() if tf.is_tensor(image) else image
    assert isinstance(pixels, np.ndarray)

    if pixels.max() > 1:
        pixels = pixels.astype('uint8')

    is_gray = len(pixels.shape) == 2 or pixels.shape[-1] == 1
    imshow_kwargs = {'cmap': 'gray'} if is_gray else {}
    ax.imshow(pixels, **imshow_kwargs)


def show_images(
        images: Union[np.ndarray, tf.Tensor],
        subplot_y: Union[None, int] = None,
        subplot_x: Union[None, int] = None
):
    """
    Show a batch of images on a grid of subplots.

    :param images: sequence of images (array, tensor, list or tuple); each item is drawn with ``show``
    :param subplot_y: number of grid rows; when None it is derived from the number of images
    :param subplot_x: number of grid columns; when None it is ``min(len(images), 4)``
    """
    assert isinstance(images, (np.ndarray, tf.Tensor, list, tuple))
    n_images = len(images)

    # at least one column, so an empty input does not lead to a division by zero
    if subplot_x is None:
        subplot_x = max(1, min(n_images, 4))
    # ceil division: a partially filled last row is still needed (floor division dropped images)
    if subplot_y is None:
        subplot_y = max(1, -(-n_images // subplot_x))

    # plt.subplots takes (nrows, ncols): rows are the y direction, columns the x direction
    fig, _ = plt.subplots(subplot_y, subplot_x, subplot_kw={'xticks': [], 'yticks': []})
    fig.subplots_adjust(wspace=0, hspace=0.05)

    for i in range(min(subplot_x * subplot_y, n_images)):
        show(images[i], fig.axes[i])


def show_key_points(
        key_points: KeyPoints,
        img: tf.Tensor
):
    """
    Draw a small cross at every key point on top of ``img`` and display the result.

    :param key_points: key points of a single image (converted with ``to_image_size``)
    :param img: image tensor with a batch of one and 1 (gray) or 3 (RGB) channels
    :raises TypeError: if an argument has the wrong type
    :raises ValueError: if more than one batch is given or the channel count is not 1 or 3
    """
    if not isinstance(key_points, KeyPoints): raise TypeError('Key points need to be of type "KeyPoints"')
    key_points = key_points.to_image_size()

    if key_points.n_batches() > 1: raise ValueError("number of batches in the key points > 1")
    if not isinstance(img, tf.Tensor): raise TypeError('image need to be of type "Tensor"')

    shape = img.get_shape().as_list()

    if shape[0] > 1: raise ValueError("number of batches in the image > 1")
    if not (shape[-1] == 1 or shape[-1] == 3):
        raise ValueError("image need to be with 3 channels (RGB) or gray level with one channel")

    # append a channel index of 0 so every point becomes a 4-column index into the image
    points = tf.concat((key_points.pt, tf.zeros((key_points.shape[0], 1), tf.float32)), -1)
    points = tf.cast(points, tf.int32)

    # offsets of the 8 pixels forming the cross arms (the centre pixel itself is not drawn)
    cross = [
        [0, -2, 0, 0],
        [0, -1, 0, 0],
        [0, 0, -2, 0],
        [0, 0, -1, 0],
        [0, 0, 1, 0],
        [0, 0, 2, 0],
        [0, 1, 0, 0],
        [0, 2, 0, 0]
    ]

    cross = tf.constant(cross, shape=(1, 8, 4), dtype=tf.int32)
    neighbor = cross + tf.expand_dims(points, 1)
    neighbor = tf.reshape(neighbor, (-1, 4))

    # keep only the cross pixels that fall inside the image (row / column 0 are valid indices)
    _, y, x, _ = tf.unstack(neighbor, 4, -1)
    mask = (y >= 0) & (y < shape[1]) & (x >= 0) & (x < shape[2])
    neighbor = tf.boolean_mask(neighbor, mask)

    # binary mask of the cross pixels, coloured with (242, 140, 40)
    kpt_image = tf.zeros([*shape[:-1], 1], dtype=img.dtype)
    kpt_image = tf.tensor_scatter_nd_add(kpt_image, neighbor, tf.ones((neighbor.get_shape()[0],), tf.float32))
    kpt_image = tf.where(kpt_image > 0, 1, 0)
    kpt_image = tf.concat((kpt_image * 242, kpt_image * 140, kpt_image * 40), -1)

    # clear the image under the crosses so that adding the overlay yields the pure cross colour
    if shape[-1] == 1:
        img_del = tf.tensor_scatter_nd_update(img, neighbor, tf.zeros((neighbor.get_shape()[0],), tf.float32))
        img_del = tf.cast(tf.repeat(img_del, 3, -1), tf.int32)
    else:
        img_del = img
        for _ in range(3):
            img_del = tf.tensor_scatter_nd_update(img_del, neighbor, tf.zeros((neighbor.get_shape()[0],), tf.float32))
            # move on to the next colour channel
            neighbor = neighbor + tf.constant([0, 0, 0, 1], dtype=tf.int32)
        # the overlay is int32; cast the image as well, otherwise the addition below
        # fails with a dtype mismatch for RGB inputs
        img_del = tf.cast(img_del, tf.int32)

    mark = kpt_image + img_del
    show_images(tf.cast(mark, tf.uint8), 1, 1)


def _make_line(
        x0: Union[tf.Tensor, int, float],
        x1: Union[tf.Tensor, int, float],
        y0: Union[tf.Tensor, int, float],
        y1: Union[tf.Tensor, int, float],
        h_limit: Union[tf.Tensor, int],
        w_limit: Union[tf.Tensor, int]
):
    """
    Rasterize the line from (x0, y0) to (x1, y1).

    The line is sampled once per integer step of x; for each sample a small vertical band
    of pixels around the line is considered and only the pixels with full coverage
    (value exactly 1) that lie inside the ``h_limit`` x ``w_limit`` area are kept.

    :return: ``(cords, values)`` where ``cords`` holds (0, y, x, 0) int32 indices and
        ``values`` the coverage of each kept pixel
    """
    dx = x1 - x0
    slope = (y1 - y0) / dx
    intercept = (x1 * y0 - x0 * y1) / dx

    xs = tf.range(x0, x1 + 1, dtype=tf.float32)
    ys = xs * slope + intercept
    ys_col = tf.reshape(ys, (-1, 1))

    # half-height of the vertical band examined around every sample
    width = tf.sqrt(1 + tf.math.abs(slope)) / 2
    band = tf.math.ceil(width / 2)
    band_offsets = tf.reshape(tf.range(-band - 1, band + 2, dtype=tf.float32), [1, -1])

    # candidate rows for every x sample, and the matching x of every candidate
    rows = (tf.reshape(tf.math.floor(ys), [-1, 1]) + band_offsets)
    cols = tf.repeat(xs, rows.get_shape()[1])

    # coverage of every candidate pixel, clipped to [0, 1]
    coverage = tf.clip_by_value(
        tf.minimum(rows + 1 + 1 / 2 - ys_col, -rows + 1 + 1 / 2 + ys_col), 0, 1
    )
    coverage = tf.reshape(coverage, (-1,))
    rows = tf.reshape(rows, (-1,))

    inside = (rows >= 0) & (cols >= 0) & (rows < h_limit) & (cols < w_limit)
    keep = tf.where(inside & (coverage == 1.), True, False)

    cols = tf.boolean_mask(cols, keep)
    rows = tf.boolean_mask(rows, keep)
    coverage = tf.boolean_mask(coverage, keep)

    # (y, x) -> (0, y, x, 0): zero batch and channel index on either side
    cords = tf.pad(tf.cast(tf.stack((rows, cols), -1), tf.int32), [[0, 0], [1, 1]])
    return cords, coverage


def plot_matches_TF(
        scr_img: tf.Tensor,
        dst_img: tf.Tensor,
        src_pt: tf.Tensor,
        dst_pt: tf.Tensor
) -> tf.Tensor:
    """
    Show two images side by side with a line drawn between every pair of matched points.

    The lines are rasterized with ``_make_line`` and painted with the colour (9, 121, 105).

    :param scr_img: source image, batch of one, 1 or 3 channels
    :param dst_img: destination image, batch of one, 1 or 3 channels
    :param src_pt: matched points in the source image, columns (batch, y, x)
    :param dst_pt: matched points in the destination image, columns (batch, y, x)
    :return: the combined image with the match lines
    :raises TypeError: if an argument is not a tensor
    :raises ValueError: if the point counts differ or a channel count is not 1 or 3
    """
    if not isinstance(scr_img, tf.Tensor) or not isinstance(dst_img, tf.Tensor):
        # this check is about the images (the message used to say 'descriptors')
        raise TypeError('images need to be of type "Tensor"')
    if not isinstance(src_pt, tf.Tensor) or not isinstance(dst_pt, tf.Tensor):
        raise TypeError('points need to be of type "Tensor"')
    if src_pt.get_shape()[0] != dst_pt.get_shape()[0]:
        raise ValueError('points need to be with the same size')

    _, h_scr, w_scr, c_scr = scr_img.get_shape().as_list()
    _, h_dst, w_dst, c_dst = dst_img.get_shape().as_list()
    if not (c_scr == 1 or c_scr == 3) or not (c_dst == 1 or c_dst == 3):
        raise ValueError("images need to be with 3 channels (RGB) or gray level with one channel")

    # bring gray images to three channels so both images can be concatenated
    if c_scr == 1:
        scr_img = tf.repeat(scr_img, 3, -1)
    if c_dst == 1:
        dst_img = tf.repeat(dst_img, 3, -1)

    h_new = max(h_scr, h_dst)
    w_new = w_scr + w_dst

    # the shorter image is centred vertically by zero padding above and below
    h_diff = h_new - min(h_scr, h_dst)
    h_up = h_diff // 2
    h_down = h_diff - h_up

    if h_scr < h_dst:
        scr_img = tf.pad(scr_img, [[0, 0], [h_up, h_down], [0, 0], [0, 0]])
    elif h_scr > h_dst:
        dst_img = tf.pad(dst_img, [[0, 0], [h_up, h_down], [0, 0], [0, 0]])

    marked_image = tf.concat((scr_img, dst_img), 2)

    _, src_y, src_x = tf.unstack(src_pt, 3, -1)
    _, dst_y, dst_x = tf.unstack(dst_pt, 3, -1)

    # shift the points of the padded image by the same vertical offset
    if h_scr < h_dst:
        src_y = src_y + h_up
    elif h_scr > h_dst:
        dst_y = dst_y + h_up

    # destination points live in the right half of the combined image
    dst_x = dst_x + w_scr
    lines = tf.zeros([*marked_image.get_shape()[:-1], 1])

    for y1, x1, y2, x2 in zip(src_y, src_x, dst_y, dst_x):
        c, val = _make_line(x1, x2, y1, y2, h_new, w_new)
        lines = tf.tensor_scatter_nd_update(lines, c, val)

    lines = tf.concat((lines * 9, lines * 121, lines * 105), -1)

    # line pixels replace the image pixels
    marked_image = tf.where(lines > 0, lines, marked_image)
    show_images(tf.cast(marked_image, tf.uint8), 1, 1)
    return marked_image


def plot_matches_CV2(
        scr_img: tf.Tensor,
        dst_img: tf.Tensor,
        src_pt: tf.Tensor,
        dst_pt: tf.Tensor
) -> tf.Tensor:
    """
    Show two images side by side with a line drawn between every pair of matched points,
    using ``cv2.line`` for the drawing.

    :param scr_img: source image, batch of one, 1 or 3 channels
    :param dst_img: destination image, batch of one, 1 or 3 channels
    :param src_pt: matched points in the source image, columns (batch, y, x)
    :param dst_pt: matched points in the destination image, columns (batch, y, x)
    :return: the combined uint8 image with the match lines, shape (1, H, W, 3)
    :raises TypeError: if an argument is not a tensor
    :raises ValueError: if the point counts differ or a channel count is not 1 or 3
    """
    if not isinstance(scr_img, tf.Tensor) or not isinstance(dst_img, tf.Tensor):
        # this check is about the images (the message used to say 'descriptors')
        raise TypeError('images need to be of type "Tensor"')
    if not isinstance(src_pt, tf.Tensor) or not isinstance(dst_pt, tf.Tensor):
        raise TypeError('points need to be of type "Tensor"')
    if src_pt.get_shape()[0] != dst_pt.get_shape()[0]:
        raise ValueError('points need to be with the same size')
    _, h_scr, w_scr, c_scr = scr_img.get_shape().as_list()
    _, h_dst, w_dst, c_dst = dst_img.get_shape().as_list()

    if not (c_scr == 1 or c_scr == 3) or not (c_dst == 1 or c_dst == 3):
        raise ValueError("images need to be with 3 channels (RGB) or gray level with one channel")

    # bring gray images to three channels so both images can be concatenated
    if c_scr == 1:
        scr_img = tf.repeat(scr_img, 3, -1)
    if c_dst == 1:
        dst_img = tf.repeat(dst_img, 3, -1)

    # the shorter image is centred vertically by zero padding above and below
    h_new = max(h_scr, h_dst)
    h_diff = h_new - min(h_scr, h_dst)
    h_up = h_diff // 2
    h_down = h_diff - h_up

    if h_scr < h_dst:
        scr_img = tf.pad(scr_img, [[0, 0], [h_up, h_down], [0, 0], [0, 0]])
    elif h_scr > h_dst:
        dst_img = tf.pad(dst_img, [[0, 0], [h_up, h_down], [0, 0], [0, 0]])

    marked_image = tf.concat((scr_img, dst_img), 2)
    # cv2 draws on NumPy arrays: drop the batch axis and convert to uint8
    marked_image = tf.squeeze(marked_image).numpy().astype('uint8')

    _, src_y, src_x = tf.unstack(src_pt, 3, -1)
    _, dst_y, dst_x = tf.unstack(dst_pt, 3, -1)

    # shift the points of the padded image by the same vertical offset
    if h_scr < h_dst:
        src_y = src_y + h_up
    elif h_scr > h_dst:
        dst_y = dst_y + h_up

    # destination points live in the right half of the combined image
    dst_x = dst_x + w_scr
    src_pt = tf.stack((src_y, src_x), -1).numpy().astype(int)
    dst_pt = tf.stack((dst_y, dst_x), -1).numpy().astype(int)

    for i in range(src_pt.shape[0]):
        # cv2 expects (x, y) while the points are stored as (y, x)
        pt1 = (int(src_pt[i, 1]), int(src_pt[i, 0]))
        pt2 = (int(dst_pt[i, 1]), int(dst_pt[i, 0]))
        cv2.line(marked_image, pt1, pt2, (9, 121, 105))

    show_images([marked_image], 1, 1)
    return tf.constant(marked_image, shape=(1, marked_image.shape[0], marked_image.shape[1], 3))


# Main

In [None]:
import sqlite3
import pandas as pd
import cv2
import matplotlib.pyplot as plt
# Connect to the SQLite tile database
db_path = '/content/drive/MyDrive/Aerial-Template-Matching/data/Google_17.db'
conn = sqlite3.connect(db_path)
# NOTE(review): conn is not closed in the cells visible here - close it once no later cell needs it.

# Read the entire 'tiles' table into a pandas DataFrame
df = pd.read_sql_query("SELECT * FROM tiles", conn)

In [None]:
# Load the reference map image with OpenCV and preview it.
google_map_path = "/content/drive/MyDrive/Aerial-Template-Matching/data/samples/Bing_17_sample.jpg"
google_map = cv2.imread(google_map_path)
# cv2.imread returns None instead of raising when the file cannot be read
if google_map is None:
    raise FileNotFoundError(f"could not read image: {google_map_path}")
# OpenCV stores colour images as BGR while matplotlib expects RGB; convert for display only
# (google_map itself stays in BGR)
plt.imshow(cv2.cvtColor(google_map, cv2.COLOR_BGR2RGB))

In [None]:
# Extract key points and descriptors from the reference image
image2 = load_image("/content/drive/MyDrive/Aerial-Template-Matching/data/samples/Bing_17_sample.jpg")
image2 = tf.image.resize(image2, [384, 512])  # Use TensorFlow resize
alg = SIFT()
kp2, desc2 = alg.keypoints_with_descriptors(image2)

In [None]:
import pickle


# Load previously saved key points / descriptors of the reference image.
# This overwrites any kp2 / desc2 already defined in the session.
# NOTE(review): pickle.load can execute arbitrary code - only load files you created yourself.
with open("keypoints_descriptors.pkl", "rb") as f:
    kp2, desc2 = pickle.load(f)

In [None]:
# Extract features from the query tile (row 136 of the tiles table) and show its key points
# (backslashes in the stored path are replaced with forward slashes)
image1 = load_image(df["Address"][136].replace("\\", "/"))
kp1, desc1 = alg.keypoints_with_descriptors(image1)
show_key_points(kp1, image1)

# key points of the reference image
show_key_points(kp2, image2)

# match the query tile against the reference image and draw the matches
src_pt, dst_pt = templet_matching_CV2(kp1, kp2, desc1, desc2, ratio_threshold=0.7)
out = plot_matches_CV2(image1, image2, src_pt, dst_pt)

# Show the ground-truth label

In [None]:
# Draw the ground-truth location of tile 136 on the reference map.
x, y = df.iloc[136][6:8]
output_image = None
# Assuming min_x and min_y are known or predefined values for the tile grid
min_x, min_y = df["x"].min(), df["y"].min()
tile_size = 256
# cv2.rectangle needs plain Python ints; the DataFrame values are NumPy scalars
pos_x = int((x - min_x) * tile_size)
pos_y = int((y - min_y) * tile_size)

# Work on an RGB copy: google_map is BGR (cv2.imread) while matplotlib expects RGB, and
# drawing on a copy keeps google_map unchanged when this cell is run again.
outpot = cv2.rectangle(cv2.cvtColor(google_map, cv2.COLOR_BGR2RGB), (pos_x, pos_y),
                       (pos_x + tile_size, pos_y + tile_size),
                       (255, 255, 0), 7)

plt.imshow(outpot)