In [None]:
import numpy as np
import scipy as sp
from skimage import io
from matplotlib import pyplot as plt
from skimage import transform
from tqdm import tqdm
from time import time
import pickle
%matplotlib inline

from utils import shift, iou
from registration import (
    set_integration_intervals,
    laguerre_zeros_precompute, 
    image_fbt_precompute, 
    fbm_registration, apply_transform
)

In [None]:
# Read the reference image and rescale it in one step.
# assumes im1.png is 8-bit so the result lands in [0, 1] — TODO confirm
im1 = io.imread("im1.png") / 255.

In [None]:
from matrix_utils import *
import cv2

def shift(im, vec):
    """Warp ``im`` with the translation built from ``vec``; output keeps the input size.

    NOTE(review): this definition shadows the ``shift`` imported from ``utils``
    at the top of the notebook — every later call uses this version.

    Parameters
    ----------
    im : array
        Image to warp; its first two axes are taken as (rows, columns).
    vec : sequence
        Unpacked into ``get_translation_mat`` — presumably (dx, dy) in pixels;
        verify against ``matrix_utils``.

    Returns
    -------
    The result of ``cv2.warpAffine`` with an output size of
    (``im.shape[1]``, ``im.shape[0]``).
    """
    n_rows, n_cols = im.shape[0], im.shape[1]
    translation = get_translation_mat(*vec)
    affine_2x3 = get_mat_2x3(translation)
    # cv2 expects the output size as (width, height).
    return cv2.warpAffine(im, affine_2x3, (n_cols, n_rows))
    
    
def rotate(im, angle, center=None):
    """Rotate ``im`` about ``center`` and return an image of the same size.

    The warp is composed as: translate ``center`` to the origin, rotate,
    translate back.

    Parameters
    ----------
    im : array
        Image to warp; its first two axes are taken as (rows, columns).
    angle : float
        Rotation angle in degrees (``get_rotation_mat`` is called with
        ``radians=False``). The sign convention is whatever ``matrix_utils``
        uses — TODO confirm.
    center : sequence, optional
        Pivot given as (x, y). Defaults to the integer image centre
        ``(w // 2, h // 2)``.

    Returns
    -------
    The result of ``cv2.warpAffine`` with output size (width, height) of ``im``.
    """
    h, w = im.shape[:2]
    if center is None:
        center = np.array([w // 2, h // 2])

    to_origin = get_translation_mat(-center[0], -center[1])
    rotation = get_rotation_mat(angle, radians=False)
    from_origin = get_translation_mat(*center)

    # Matrices apply right-to-left: move pivot to origin, rotate, move back.
    about_center = from_origin @ rotation @ to_origin
    return cv2.warpAffine(im, get_mat_2x3(about_center), (w, h))

In [None]:
# Build the synthetic "moving" image: rotate the reference by a known angle
# (degrees, since rotate() calls get_rotation_mat with radians=False), then
# translate it by a known offset.
init_ang = -31
im2 = rotate(im1, init_ang)
im2 = shift(im2, (2, 3))

# Round before casting: a bare uint8 cast truncates, so an interpolated value
# such as 254.9999 would be written as 254 instead of 255.
io.imsave('im2.png', np.uint8(np.rint(im2 * 255)))

# Overlay: reference in the red channel, moving image in green, blue left empty.
plt.imshow(np.stack([im1, im2, im2*0], -1))
plt.show()

In [None]:
# Settings shared by the registration cells below.
image_radius = 128     # forwarded as `image_radius` to set_integration_intervals / fbm_registration
pixel_sampling = 0.5   # forwarded as the sampling argument (`p_s` in fbm_registration)
com_offset = 20.       # forwarded as `com_offset`

# Earlier experiment values, currently unused:
# bandwidth = 128
# lag_func_num=60
# lag_scale=3

In [None]:
# Precompute the index ranges and sampling grids for the chosen settings.
# NOTE: this also (re)binds the notebook-level names eps, b and bandwidth.
(Im1, Ih1, Imm,
 theta_net, u_net, x_net,
 omega_net, psi_net, eta_net,
 eps, b, bandwidth) = set_integration_intervals(image_radius, pixel_sampling, com_offset)

# Collect every distinct |m1 + h1 + mm| in first-seen order.
# A companion set gives O(1) membership tests; scanning the growing list
# inside a triple loop made this step needlessly slow.
alphas = []
seen_alphas = set()
for m1 in tqdm(Im1):
    for h1 in Ih1:
        for mm in Imm:
            abs_sum = abs(m1 + h1 + mm)
            if abs_sum in seen_alphas:
                continue
            seen_alphas.add(abs_sum)
            alphas.append(abs_sum)

# image_radius = 2 * bandwidth * pixel_sampling / np.pi
# print(image_radius)
print(theta_net.shape)
print(u_net.shape)
print(x_net.shape)

In [None]:
# The images themselves double as masks (same objects, no copy); the `> 0`
# thresholding is applied where the masks are consumed.
mask1, mask2 = im1, im2

In [None]:
from skimage.measure import label, regionprops

# Binarise before labelling, matching the `mask1 > 0` passed to
# fbm_registration. label() is documented for integer/boolean label images;
# feeding it the float-valued image relied on an implicit conversion.
props = regionprops(label(mask1 > 0))
if not props:
    raise ValueError("mask1 contains no foreground pixels; cannot compute a centroid")

# regionprops reports the centroid as (row, col); store the centre as (x, y).
# NOTE(review): props[0] is the first labelled region, not necessarily the largest.
c1y, c1x = props[0].centroid
center = c1x, c1y
print(center)

In [None]:
# Time only the registration call itself; applying the transform and scoring
# happen after the clock is stopped.
start = time()
reg1 = fbm_registration(
    im1,
    im2,
    image_radius=image_radius,
    p_s=pixel_sampling,
    com_offset=com_offset,
    method='fbm',
    masks=[mask1 > 0, mask2 > 0],
    shift_by_mask=False,
)
end = time()

# Warp the moving image with the recovered transform about `center`
# and compare it with the reference.
im_reg1 = apply_transform(im2, reg1, center)
print('IoU:', iou(im1, im_reg1))
print('Time:', end - start, 'secs')

In [None]:
# Red = reference, green = moving image, blue channel left empty.
overlay_before = np.stack([im1, im2, im2*0], axis=-1)
overlay_after = np.stack([im1, im_reg1, im2*0], axis=-1)

# Figure 1: before registration.
plt.figure()
plt.imshow(overlay_before)

# Figure 2: after registration, zoomed to a fixed window
# (y limits are given high-to-low, keeping the image's top-down row order).
plt.figure()
plt.imshow(overlay_after)
plt.xlim(50, 200)
plt.ylim(200, 50)


In [None]:
# Show the registration result returned by fbm_registration (displayed as the cell's last expression).
reg1

In [None]:
# NOTE(review): repeats the display of reg1 from the previous cell — likely removable.
reg1

<div>
<img src="illustration.jpeg" width="500"/>
</div>

In [None]:
from utils import normalize_alpha


In [None]:
# Manually undo the registered motion on im2: translate by (-x, -y), then
# rotate about the image centre, plotting and scoring the result.
# NOTE: this cell rebinds the notebook-level names center, eps and com_offset.
transform_dict = reg1
h, w = im2.shape[:2]
center = np.array([w // 2, h // 2])
image = im2.copy()
psi = transform_dict['psi']
etta_prime = transform_dict['etta']
omegga_prime = transform_dict['omegga']
etta = etta_prime - psi
omegga = omegga_prime - etta_prime
eps = transform_dict['eps']
com_offset = transform_dict['com_offset']

# Net translation: the sum of two steps of length com_offset, taken at
# angles psi and psi + etta + eps.
x = com_offset * np.cos(psi) + com_offset * np.cos(psi + etta + eps)
y = com_offset * np.sin(psi) + com_offset * np.sin(psi + etta + eps)

# Magnitude of that translation. Taken directly from (x, y): the previous
# closed form used cos(psi + etta + eps), whereas the law of cosines for the
# two steps gives cos(etta + eps), so it matched (x, y) only when psi == 0.
rho = np.hypot(x, y)

# NOTE(review): 3 looks like the y component of the synthetic (2, 3) shift used
# to create im2, making this a sanity check of rho against it — confirm.
cos_alpha = 3/rho
print('rho', rho, cos_alpha, np.degrees(np.arccos(cos_alpha)), rho*np.sin(np.arccos(cos_alpha)))
print(x, y)

alpha = omegga_prime + eps
print(np.degrees(alpha), np.degrees(normalize_alpha(alpha)))

# Step 1: apply the negated translation.
mat_trans = get_translation_mat(-x, -y)
im_shifted = cv2.warpAffine(
    im2, get_mat_2x3(mat_trans),
    (im2.shape[1], im2.shape[0])
)
plt.figure()
plt.imshow(np.stack([im1, im_shifted, im2*0], -1))

# Step 2: rotate by the negated, normalised angle about the image centre
# (move centre to origin, rotate, move back).
mat_trans_minus_center = get_translation_mat(-center[0], -center[1])
mat_rot = get_rotation_mat(-normalize_alpha(alpha), radians=True)
mat_trans_center = get_translation_mat(center[0], center[1])
im_rotated = cv2.warpAffine(
    im_shifted, get_mat_2x3(mat_trans_center @ mat_rot @ mat_trans_minus_center),
    (im2.shape[1], im2.shape[0])
)
plt.figure()
plt.imshow(np.stack([im1, im_rotated, im2*0], -1))
print('IoU:', iou(im1, im_rotated))

In [None]:
# Replay the recovered motion on im2 one elementary step at a time
# (rotate, shift, rotate, shift, rotate, flip) and plot every stage over im1.
# NOTE: this cell rebinds the notebook-level names center, eps and b.
transform_dict = reg1
h, w = im2.shape[:2]
center = np.array([w // 2, h // 2])
image = im2.copy()

psi = transform_dict['psi']
etta_prime = transform_dict['etta']
omegga_prime = transform_dict['omegga']
etta = etta_prime - psi
omegga = omegga_prime - etta_prime
eps = transform_dict['eps']
b = transform_dict['com_offset']

print(np.degrees(psi), np.degrees(etta), np.degrees(omegga), np.degrees(eps))
print(np.degrees(omegga_prime + eps))

# rotate() takes degrees, while the angles in the dict are converted from radians.
im_f1 = rotate(im2, np.degrees(omegga))
im_f2 = shift(im_f1, (b, 0))
im_f3 = rotate(im_f2, np.degrees(etta))
im_f4 = shift(im_f3, (b, 0))
im_f5 = rotate(im_f4, np.degrees(psi))
# Reverse both image axes.
im_f6 = im_f5[::-1, ::-1]

# One figure per stage, in the order the stages were produced.
for stage_image in (im_f1, im_f2, im_f3, im_f4, im_f5, im_f6):
    plt.figure()
    plt.imshow(np.stack([im1, stage_image, im2*0], -1))

# The title lands on the last figure (the final stage).
plt.title(f'iou: {iou(im1, im_f6)}')