In [1]:
from functools import lru_cache
import torch, torchvision
import matplotlib.pyplot as plt
import numpy as np
import math
class PhaseCorrelation(torch.nn.Module):
    """Phase-only correlation of two real image batches over their last two axes.

    Returns ``irfft2(F(im) * conj(F(template)) / |F(im) * F(template)|)``. When ``im``
    is ``template`` circularly shifted by ``(dy, dx)``, the unshifted output peaks at
    index ``(dy, dx)``.

    Args:
        shift: if True, pass the result through ``torch.fft.ifftshift`` before returning.

    NOTE(review): ``ifftshift`` is called without ``dim``, so it rolls *every* axis
    (batch and channel included), not only the two spatial ones. This is kept for
    backward compatibility, because a caller may rely on it to undo an all-axes
    ``fftshift`` of its inputs. For spatial-only centring, use ``shift=False`` plus
    ``torch.fft.fftshift(out, dim=(-2, -1))``.
    """

    def __init__(self, shift=True):
        super(PhaseCorrelation, self).__init__()
        self.shift = shift

    def forward(self, im, template):
        im_fft = torch.fft.rfft2(im)
        template_fft = torch.fft.rfft2(template)
        cross_power = im_fft * template_fft.conj()
        # |a * b| == |a * conj(b)|. Clamp to the smallest positive normal float so that
        # frequencies where either spectrum vanishes give 0 instead of 0 / 0 = NaN.
        magnitude = (im_fft * template_fft).abs()
        magnitude = magnitude.clamp_min(torch.finfo(magnitude.dtype).tiny)
        # Pass the spatial size explicitly: without `s`, irfft2 always returns an even
        # last axis and silently drops a column for odd widths.
        out = torch.fft.irfft2(cross_power / magnitude, s=im.shape[-2:])
        if not self.shift:
            return out
        return torch.fft.ifftshift(out)
    
class HighPassFilter(torch.nn.Module):
    """Return ``|img|`` weighted by an inverted 2-D Blackman window.

    The weight is ``max(w) - w`` with ``w = outer(blackman(H), blackman(W))``. It is
    zero where the window product peaks (around the centre of the last two axes) and
    grows towards the borders. Applied to an ``fftshift``-ed spectrum, it damps the
    low frequencies.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    @lru_cache(maxsize=8)
    def _blackman_kernel(shape):
        """Build the (1, 1, H, W) weight on the CPU for ``shape == (H, W)``; cached per shape."""
        window = torch.outer(torch.blackman_window(shape[0]), torch.blackman_window(shape[1]))
        return torch.abs(window - torch.max(window)).unsqueeze(0).unsqueeze(0)

    def get_kernel(self, shape):
        """Return the cached (1, 1, H, W) weight for the trailing spatial ``shape``.

        The cache is keyed on the shape only. Caching the method directly would key on
        ``self`` too, keeping the last instance alive and never sharing kernels
        between instances.
        """
        return self._blackman_kernel(tuple(shape))

    def forward(self, img):
        magnitude = img.abs()
        # Move the CPU-built kernel to the input's device; dtype follows normal promotion.
        return magnitude * self.get_kernel(magnitude.shape[-2:]).to(magnitude.device)
    
    
class RigidTransform(torch.nn.Module):
    """Warp a batch of images with a per-sample similarity transform (scale, rotation, shift).

    Each output pixel with normalized coordinates ``p = (x, y)`` in ``[-1, 1]`` is
    bilinearly sampled from the input at ``scale * R(rot) @ p + t``. Here ``R`` is
    ``rotmat_2d(rot)``, and ``t`` is the pixel translation divided by ``W // 2`` (x)
    and ``H // 2`` (y).

    Args:
        scale_alpha: per-sample scale factor, shape (N,).
        trans_x, trans_y: per-sample translations in pixels, shape (N,).
        rot_beta: per-sample rotation angle in degrees, shape (N,).

    NOTE(review): with ``align_corners=True`` one pixel spans ``2 / (W - 1)`` normalized
    units, so the ``W // 2`` conversion is exact only for odd sizes.
    """

    def __init__(self, scale_alpha, trans_x, trans_y, rot_beta):
        super().__init__()
        self.scale_alpha = scale_alpha
        self.trans_x = trans_x
        self.trans_y = trans_y
        self.rot_beta = rot_beta

    @staticmethod
    def rotmat_2d(rot: torch.Tensor) -> torch.Tensor:
        """Return the (..., 2, 2) matrices ``[[cos, sin], [-sin, cos]]`` for angles ``rot`` in degrees."""
        cos_theta = torch.cos(rot / 180 * torch.pi)
        sin_theta = torch.sin(rot / 180 * torch.pi)
        R = torch.stack([
            torch.stack([cos_theta, sin_theta], dim=-1),
            torch.stack([-sin_theta, cos_theta], dim=-1)
        ], dim=-2)
        return R

    def forward(self, image):
        """Resample ``image`` (N, C, H, W) through the stored transform (bilinear, zero padding)."""
        N, _, H, W = image.shape
        factory = dict(device=image.device, dtype=image.dtype)
        # (N, 2, 3) affine matrix [scale * R | t] with t converted to normalized units.
        linear = self.scale_alpha[..., None, None] * self.rotmat_2d(self.rot_beta)
        offset = torch.stack([self.trans_x / (W // 2), self.trans_y / (H // 2)], dim=-1)[..., None]
        affine = torch.cat([linear, offset], dim=-1).to(**factory)
        # With indexing='xy' the first vector spans columns, so it needs W samples; the
        # results are (H, W) arrays, correct for non-square images as well.
        xs, ys = torch.meshgrid(
            torch.linspace(-1, 1, W, **factory),
            torch.linspace(-1, 1, H, **factory),
            indexing='xy',
        )
        coords = torch.stack([xs.reshape(-1), ys.reshape(-1), torch.ones_like(xs).reshape(-1)])  # (3, H*W)
        grid = (affine @ coords).reshape(N, 2, H, W).movedim(1, -1)  # (N, H, W, 2), last axis = (x, y)
        return torch.nn.functional.grid_sample(
            image,
            grid,
            mode='bilinear',
            align_corners=True,
        )
        
class LogPolarRepresentation(torch.nn.Module):
    """Log-polar resampling of ``(N, C, H, W)`` images around pixel ``(H // 2, W // 2)``.

    ``cart2pol`` produces ``(N, C, 360, get_radius())`` maps: one row per degree and
    log-spaced radii along the columns. ``pol2cart`` maps such a representation back
    onto an ``H x W`` Cartesian grid.
    """

    def __init__(self, H, W):
        super().__init__()
        self.H = H
        self.W = W

    def get_radius(self):
        """Number of radial samples: half the image diagonal, truncated to an int."""
        return (torch.norm(torch.tensor([self.H, self.W], dtype=torch.float32)) / 2).long().item()

    def remap(self, img, grid):
        """Bilinearly sample every image in the batch at the same normalized ``grid`` (clamped to [-1, 1])."""
        # grid_sample requires the grid on the same device and with the same dtype as the input.
        grid = grid.clamp(-1, 1).to(device=img.device, dtype=img.dtype)
        batch_grid = grid.unsqueeze(0).expand(img.size(0), -1, -1, -1)
        return torch.nn.functional.grid_sample(
            img, batch_grid, mode='bilinear', align_corners=True
        )

    def xy_grid(self):
        """Return ``(X, Y)``, each ``(H, W)``: X runs 1 -> -1 along columns, Y runs 1 -> -1 along rows."""
        # With indexing="xy" the first vector spans columns, so it must have W samples.
        return torch.meshgrid(torch.linspace(1, -1, self.W), torch.linspace(1, -1, self.H), indexing="xy")

    def pol2cartgrid(self):
        """Normalized (log-radius, angle) sampling coordinates for every Cartesian output pixel."""
        radius = self.get_radius()
        indices = torch.complex(*self.xy_grid())
        # |indices| / sqrt(2) is 1 at the image corners, where the normalized log-radius reaches 1.
        log_radius = (radius * indices.abs() / math.sqrt(2)).log() / math.log(radius) * 2 - 1
        return log_radius, indices.angle() / torch.pi

    def thetarho_grid(self):
        """Angle (degrees 0..359) and log-spaced radius ``radius ** (r / radius)`` grids, each ``(360, radius)``."""
        radius = self.get_radius()
        theta, r = torch.meshgrid(torch.arange(360), torch.arange(radius), indexing="ij")
        return theta, radius ** (r / radius)

    def cart2polgrid(self):
        """Pixel coordinates ``(x, y)`` of every log-polar sample, centred at ``(W // 2, H // 2)``."""
        theta, rho = self.thetarho_grid()
        indices = torch.polar(rho, theta * torch.pi / 180)
        return indices.real + self.W // 2, indices.imag + self.H // 2

    def cart2pol(self, img):
        """Resample ``img`` (N, C, H, W) onto the ``(360, radius)`` log-polar grid."""
        x_pix, y_pix = self.cart2polgrid()
        height, width = img.shape[-2:]
        # x is a column coordinate, so normalize by W - 1; y is a row coordinate, so
        # normalize by H - 1 (align_corners=True convention).
        grid = torch.stack([x_pix / (width - 1) * 2 - 1, y_pix / (height - 1) * 2 - 1], dim=-1)
        return self.remap(img, grid)

    def pol2cart(self, img):
        """Map a ``(N, C, 360, radius)`` log-polar representation back onto an ``(H, W)`` grid."""
        grid = torch.stack(self.pol2cartgrid(), dim=-1)
        return self.remap(img, grid)
class MellinFourierRegistration(torch.nn.Module):
    def __init__(self, H, W):
        super().__init__()
        self.logPolar = LogPolarRepresentation(H, W)
        self.highPassFilter = HighPassFilter()
    
    def transform_rot_parameter(self, rotIdx):
        return -rotIdx +90
    def transform_scale_parameter(self, scaleIdx, width):
        radius = (torch.norm(torch.tensor([self.logPolar.H, self.logPolar.W], dtype=torch.float32))/2).long().item()
        return (radius ** (-((scaleIdx - width//2)) / radius))
    def transform_translation_parameters(self, iMax, jMax, pcH, pcW):
        return iMax.where(iMax<pcH//2, iMax-pcH),jMax.where(jMax<pcW//2, jMax-pcW)
    def get_rot_scale(self, pc_rot_scale, ):
        topPc = pc_rot_scale.abs().flatten(-2).topk(1, dim=(-1))
        predsRotsScale = torch.stack(torch.unravel_index(topPc.indices, pc_rot_scale.shape[-2:]), -2)
        estRot = self.transform_rot_parameter(predsRotsScale[:,0,0])
        estScale = self.transform_scale_parameter(predsRotsScale[:,1], pc_rot_scale.size(-1))[...,0] 
        return estRot, estScale
    
    def get_translations(self, pc_translat ):
        iMax, jMax =  torch.unravel_index(pc_translat.flatten(-2).sum(-2).argmax(-1), pc_translat.shape[-2:]) 
        estTrans = torch.stack(self.transform_translation_parameters(iMax, jMax, *pc_translat.shape[-2:]), dim=-1)
        return estTrans
    
    def forward(self, image, template):
        imageFft = torch.fft.fftshift(torch.fft.fft2(image))
        templateFft = torch.fft.fftshift(torch.fft.fft2(template))
        imageFftAbs = self.highPassFilter(imageFft)
        templateFftAbs =  self.highPassFilter(templateFft)
        imageLogPolar = self.logPolar.cart2pol(imageFftAbs)
        templateLogPolar = self.logPolar.cart2pol(templateFftAbs)
        pcRotScale = PhaseCorrelation(shift=True)(imageLogPolar[...,:180,:],templateLogPolar[...,:180,:]).sum(-3)
        estRot, estScale = self.get_rot_scale(pcRotScale)
        templateUnrotUnscaled = RigidTransform(1/estScale, torch.zeros_like(estScale), torch.zeros_like(estScale), -estRot )(template)
        pcTranslat = PhaseCorrelation(shift=False)(image,templateUnrotUnscaled)
        estTrans = self.get_translations(pcTranslat)
        return {
            'estRot': estRot,
            'estScale': estScale,
            'estTrans': estTrans,
            'pcRotScale': pcRotScale,
            'pcTranslat': pcTranslat
        } 
        
    def register_image(self,image, template):
        params = self(image, template)
        estTransTransf= torch.einsum("bcd, bd -> bc", RigidTransform.rotmat_2d(params['estRot']), params['estTrans'].float())/params['estScale'].unsqueeze(-1)
        imageTransfInv= RigidTransform(1/params['estScale'],-estTransTransf[:,1], -estTransTransf[:,0], -params['estRot'])(template)
        return dict(registered=imageTransfInv, params=params)
    
def scale_loss(estScale, gtScale):
    """Sum over the last axis of the squared relative scale error ``(est / gt - 1) ** 2``."""
    relative_error = estScale / gtScale - 1
    return torch.sum(relative_error ** 2, dim=-1)
def rot_loss(estRot, gtAngle):
    """Sum over the last axis of the squared angle error, measured in units of 180 degrees.

    Differences are not wrapped, so estimates 180 degrees apart contribute a loss of 1.
    """
    normalised = (estRot - gtAngle) / 180
    return torch.sum(normalised ** 2, dim=-1)


# Fix the RNG so the random ground-truth transforms (and everything derived from them) are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Reference image scaled to [0, 1]; the path is relative to the notebook's working directory.
image = torchvision.io.read_image('public/mandrill.png') / 255.
N = 20  # number of random transforms, i.e. batch size
bimage = image.unsqueeze(0).expand(N, -1, -1, -1)
# Ground truth: scale in [0.55, 1.55), integer angle in [-90, 90) degrees, shifts in [-75, 75) pixels.
gtScale = torch.rand(N) + 0.55
gtAngle = torch.randint(-90, 90, (N,)).float()
gtTransX = 75 * (torch.rand(N) * 2 - 1)
gtTransY = 75 * (torch.rand(N) * 2 - 1)
bimageTransf = RigidTransform(gtScale, gtTransX, gtTransY, gtAngle)(bimage)

In [2]:
# Estimate each sample's transform against bimage and warp bimageTransf (the template) back.
mf = MellinFourierRegistration(bimage.size(-2), bimage.size(-1))
ans = mf.register_image(image=bimage, template=bimageTransf)
imageTransfRegistered = ans["registered"]


In [49]:
nQuery = 2
H, W = ans["params"]["pcRotScale"].size()[-2:]
# Every (row, col) index of the rot/scale correlation map in row-major order (not random despite the name).
randomRotScale = torch.unravel_index(torch.arange(H * W), (H, W))

In [50]:
# Display the (row, col) index tensors covering every rot/scale correlation entry.
randomRotScale

(tensor([  0,   0,   0,  ..., 179, 179, 179]),
 tensor([  0,   1,   2,  ..., 359, 360, 361]))

In [52]:
# (angle, scale) pair implied by every rot/scale correlation index.
torch.stack(
    (
        mf.transform_rot_parameter(randomRotScale[0]),
        mf.transform_scale_parameter(randomRotScale[1], W),
    ),
    dim=-1,
)

tensor([[ 9.0000e+01,  1.9026e+01],
        [ 9.0000e+01,  1.8719e+01],
        [ 9.0000e+01,  1.8417e+01],
        ...,
        [-8.9000e+01,  5.5189e-02],
        [-8.9000e+01,  5.4298e-02],
        [-8.9000e+01,  5.3421e-02]])

In [55]:
H, W = ans["params"]["pcTranslat"].size()[-2:]
# Every (row, col) index of the translation correlation map in row-major order.
randomTranslat = torch.unravel_index(torch.arange(H * W), (H, W))

In [56]:
# Signed shifts implied by every translation correlation index.
mf.transform_translation_parameters(*randomTranslat, H, W)

(tensor([ 0,  0,  0,  ..., -1, -1, -1]),
 tensor([ 0,  1,  2,  ..., -3, -2, -1]))

In [None]:

# Left: transformed input; right: registered result overlaid on the reference image.
fig, axs = plt.subplots(10, 2, figsize=(8, 24))
for i, (ax_in, ax_out) in enumerate(axs):
    ax_in.imshow(bimageTransf[i].moveaxis(0, -1))
    ax_in.set_title('Rigid Transformed')
    ax_in.axis('off')
    ax_out.imshow(imageTransfRegistered[i].moveaxis(0, -1))
    ax_out.set_title('Registered Image')
    ax_out.imshow(bimage[i].moveaxis(0, -1), alpha=0.5)
    ax_out.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# (Removed a stray `qsdqsd` expression that raised NameError and broke Restart & Run All.)

In [None]:
# Log-polar resampling of reference sample 0, averaged over channels.
plt.imshow(
    LogPolarRepresentation(bimage.size(-2), bimage.size(-1)).cart2pol(bimage)[0].mean(dim=0)
)

In [None]:
# Channel-averaged view of transformed sample 6.
plt.imshow(bimageTransf[6].mean(dim=0))

In [None]:
# Log-polar resampling of transformed sample 6, averaged over channels.
plt.imshow(
    LogPolarRepresentation(bimage.size(-2), bimage.size(-1)).cart2pol(bimageTransf)[6].mean(dim=0)
)

In [None]:
# Round trip Cartesian -> log-polar -> Cartesian for sample 6, overlaid on the original.
plt.imshow(
    LogPolarRepresentation(bimage.size(-2), bimage.size(-1)).pol2cart(
        LogPolarRepresentation(bimage.size(-2), bimage.size(-1)).cart2pol(bimage)
    )[6].mean(dim=0)
)
plt.colorbar()
plt.imshow(bimage[6].mean(dim=0), alpha=0.5)

In [None]:
hp_filt = HighPassFilter()
logpol = LogPolarRepresentation(*bimage.shape[-2:])


def fft2d(x):
    """Centred 2-D spectrum of ``x``: ``fft2`` over the last two axes, then ``fftshift``.

    ``fft2d`` was called here but never defined in the notebook (NameError on a fresh
    kernel). NOTE: ``fftshift`` without ``dim`` rolls every axis, so for a batch of N
    images ``result[k]`` is the spectrum of sample ``(k - N // 2) % N``. The
    ``PhaseCorrelation()`` call further down also shifts back over every axis, which
    undoes this roll for the rot/scale estimate. Keep the two in sync.
    """
    return torch.fft.fftshift(torch.fft.fft2(x))


bimageFft = fft2d(bimage)
bimageTransfFft = fft2d(bimageTransf)
bimageFftAbs = hp_filt(bimageFft)
bimageFftTransfAbs = hp_filt(bimageTransfFft)

In [None]:
# Log-polar resampling of the weighted magnitude spectra.
bimageLogPolMag = logpol.cart2pol(img=bimageFftAbs)
bimageTransfLogPolMag = logpol.cart2pol(img=bimageFftTransfAbs)

In [None]:
# Channel-averaged log-polar magnitude at batch index 6.
plt.imshow(bimageLogPolMag[6].mean(dim=0))

In [None]:
bimageLogPolMag.size()

In [None]:
# Rot/scale phase correlation over the first 180 one-degree rows, summed over channels.
pc = PhaseCorrelation(shift=True)(
    bimageLogPolMag[..., :180, :], bimageTransfLogPolMag[..., :180, :]
).sum(dim=-3)

plt.imshow(pc.abs()[0])

In [None]:
# Peak of |pc| -> (angle, scale), mirroring MellinFourierRegistration.get_rot_scale.
topPc = pc.abs().flatten(-2).topk(k=1, dim=-1)
predsRotsScale = torch.stack(torch.unravel_index(topPc.indices, pc.shape[-2:]), dim=-2)
estRot = 90 - predsRotsScale[:, 0, 0]
radius = int((torch.norm(torch.tensor(bimage.shape[-2:], dtype=torch.float32)) / 2).long())
estScale = (radius ** (-(predsRotsScale[:, 1] - pc.size(-1) // 2) / radius))[..., 0]

In [None]:
torch.remainder(estRot - gtAngle, 180)

In [None]:
(scale_loss(estScale, gtScale), rot_loss(estRot, gtAngle))

In [None]:
# Undo only the estimated rotation and scale (no translation).
bimageUnrotUnscaled = RigidTransform(
    scale_alpha=1 / estScale,
    trans_x=torch.zeros_like(estScale),
    trans_y=torch.zeros_like(estScale),
    rot_beta=-estRot,
)(bimageTransf)

In [None]:
# Translation from the channel-summed peak of an unshifted phase correlation.
pc_translat = PhaseCorrelation(shift=False)(bimage, bimageUnrotUnscaled)
iMax, jMax = torch.unravel_index(
    pc_translat.flatten(-2).sum(-2).argmax(-1), pc_translat.shape[-2:]
)
estTrans = torch.stack(
    [
        torch.where(iMax < pc_translat.size(-2) // 2, iMax, iMax - pc_translat.size(-2)),
        torch.where(jMax < pc_translat.size(-1) // 2, jMax, jMax - pc_translat.size(-1)),
    ],
    dim=-1,
)
gtTrans = torch.stack((gtTransY, gtTransX), dim=-1)

In [None]:
gtTrans.to(torch.long) - estTrans

In [None]:
# Ground-truth (y, x) translations in pixels, for comparison with estTrans.
gtTrans

In [None]:
# Rotate the estimated (row, col) shift by R(estRot) and undo the scale, then warp back.
estTransTransf = torch.einsum(
    "bcd, bd -> bc", RigidTransform.rotmat_2d(estRot), estTrans.float()
) / estScale.unsqueeze(-1)
bimageUntransf = RigidTransform(
    scale_alpha=1 / estScale,
    trans_x=-estTransTransf[:, 1],
    trans_y=-estTransTransf[:, 0],
    rot_beta=-estRot,
)(bimageTransf)

fig, axs = plt.subplots(10, 3, figsize=(8, 24))
for i, (ax_in, ax_mid, ax_out) in enumerate(axs):
    ax_in.imshow(bimageTransf[i].mean(0))
    ax_in.set_title('Rigid Transformed')
    ax_in.axis('off')
    ax_mid.imshow(bimageUnrotUnscaled[i].mean(0))
    ax_mid.set_title('Unrot & Unscaled Image')
    ax_mid.axis('off')
    ax_out.imshow(bimageUntransf[i].mean(0))
    ax_out.set_title('Registered Image')
    ax_out.imshow(bimage[i].mean(0), alpha=0.5)
    ax_out.axis('off')
plt.tight_layout()
plt.show()

In [None]:
estTrans[..., 1]

In [None]:
# Rotated and rescaled translation used to warp bimageTransf back onto bimage.
estTransTransf