In [1]:
import torch
import torch as tr
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from spatial_correlation_sampler import SpatialCorrelationSampler

In [2]:
# Benchmark configuration: one random feature pair that both matchers below are checked on.
batch = 1
channel = 4
patch_size = 5  # side of the local correlation window; max displacement is patch_size // 2
height = 14
width = 14

torch.manual_seed(0)  # reproducible inputs under Restart & Run All
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allocate directly on the target device: `torch.randn(..., requires_grad=True).cuda()`
# returns a non-leaf copy, so `input1.grad` would never be populated by backward().
input1 = torch.randn(batch,
                     channel,
                     height,
                     width,
                     dtype=torch.float32,
                     device=device,
                     requires_grad=True)
input2 = torch.randn_like(input1)  # same shape/dtype/device; requires_grad defaults to False

In [3]:
def _diagonals_to_offsets(x, target_dim, source_dim, patch_size):
    """Gather x[.., t, .., s, ..] for t = s + d, d in [-P//2, P//2], zero-filling misses.

    ``target_dim`` and ``source_dim`` must have the same size n. Both dims are
    removed and two trailing dims (P, n) are appended: slot ``d + P//2`` along
    the first, source index ``s`` along the second. The result keeps ``x``'s
    device and dtype because it is built with ``tr.stack``.
    """
    max_d = patch_size // 2
    n = x.size(source_dim)
    slots = []
    for d in range(-max_d, max_d + 1):
        # torch.diagonal(offset=k) pairs index i of dim1 with index i + k of dim2,
        # so "target = source + d" corresponds to offset k = -d.
        diag = tr.diagonal(x, offset=-d, dim1=target_dim, dim2=source_dim)
        # Sources whose target falls off the map (all n of them when |d| >= n).
        missing = n - diag.size(-1)
        # d < 0: the first |d| sources have no partner -> zeros in front; d > 0: at the end.
        pad = (missing, 0) if d < 0 else (0, missing)
        slots.append(F.pad(diag, pad))
    return tr.stack(slots, dim=-2)


def corr_abs_to_rel(corr, h, w, patch_size=None):
    """Convert an all-pairs correlation into a displacement-indexed local one.

    Args:
        corr: tensor viewable as (b, h*w, h*w). Dim 1 indexes target positions
            (y2, x2) and dim 2 indexes source positions (y1, x1), both row-major.
        h, w: spatial size shared by the source and target maps.
        patch_size: side P of the displacement window. Defaults to the
            notebook-level ``patch_size`` so existing calls keep working.

    Returns:
        Tensor of shape (b, P*P, h*w) on ``corr``'s device and dtype. Channel
        ``(dy + P//2) * P + (dx + P//2)`` at source (y1, x1) holds the correlation
        with target (y1 + dy, x1 + dx), or 0 when that target lies outside the map.
    """
    if patch_size is None:
        # Backward compatibility: fall back to the notebook's global config value.
        patch_size = globals()["patch_size"]

    b, _, _ = corr.size()
    corr = corr.view(b, h, w, h, w)                          # [b, y2, x2, y1, x1]
    w_rel = _diagonals_to_offsets(corr, 2, 4, patch_size)    # [b, y2, y1, P_dx, x1]
    hw_rel = _diagonals_to_offsets(w_rel, 1, 2, patch_size)  # [b, P_dx, x1, P_dy, y1]
    hw_rel = hw_rel.permute(0, 3, 1, 4, 2).contiguous()      # [b, P_dy, P_dx, y1, x1]
    return hw_rel.view(-1, patch_size * patch_size, h * w)

In [5]:
relu = nn.ReLU()


def L2normalize(x, d=1):
    """Scale ``x`` to unit L2 norm along dimension ``d``.

    A small epsilon is added to the sum of squares before the square root,
    so all-zero slices do not divide by zero.
    """
    eps = 1e-6
    squared_sum = (x ** 2).sum(dim=d, keepdim=True)
    magnitude = (squared_sum + eps) ** (0.5)
    return x / magnitude

def match_layer_mm(feature1, feature2):
    """Local cosine-similarity correlation computed densely with a batched matmul.

    Dense counterpart of ``match_layer_scs`` (the notebook checks that the two
    agree). It builds every source/target similarity with ``torch.bmm`` and then
    keeps only the local displacement window via ``corr_abs_to_rel``.

    Args:
        feature1: source features, (b, c, h, w).
        feature2: target features, (b, c, h, w), with the same spatial size as feature1.

    Returns:
        ReLU-clamped correlation of shape (b, patch_size**2, h*w).

    Raises:
        ValueError: if the two feature maps differ in spatial size.
    """
    feature1 = L2normalize(feature1)
    feature2 = L2normalize(feature2)
    b, c, h1, w1 = feature1.size()
    _, _, h2, w2 = feature2.size()
    if (h1, w1) != (h2, w2):
        # corr_abs_to_rel views the matrix as (b, h, w, h, w), so the sizes must agree.
        raise ValueError(
            f"feature maps must share a spatial size, got {(h1, w1)} and {(h2, w2)}")
    # (b, h2*w2, h1*w1): rows are target (feature2) positions, columns are source
    # (feature1) positions. reshape tolerates non-contiguous inputs where view would fail.
    corr = torch.bmm(feature2.reshape(b, c, h2 * w2).transpose(1, 2),
                     feature1.reshape(b, c, h1 * w1))
    # Use the actual feature size rather than the global height/width config.
    corr = corr_abs_to_rel(corr, h1, w1)  # (b, patch_size**2, h1*w1)
    return relu(corr)

# Shared local-correlation operator used by match_layer_scs: a 1x1 kernel,
# stride 1, no padding, and a patch_size x patch_size window with unit dilation.
correlation_sampler = SpatialCorrelationSampler(
    patch_size=patch_size,
    kernel_size=1,
    dilation_patch=1,
    stride=1,
    padding=0,
)
def match_layer_scs(feature1, feature2):
    """Local cosine-similarity correlation via the shared ``correlation_sampler``.

    Both inputs are L2-normalized along dim 1 (channels). They are then
    correlated by the sampler, its two window dims are merged into one
    channel axis, and the result is clamped with ReLU.
    """
    normed_src = L2normalize(feature1)
    normed_tgt = L2normalize(feature2)
    n, _, rows, cols = normed_src.size()
    _, _, _, _ = normed_tgt.size()  # only asserts feature2 is 4-D, as before
    sampled = correlation_sampler(normed_src, normed_tgt)  # (b,p,p,h,w)
    flat = sampled.view(n, -1, rows * cols)
    return relu(flat)

In [6]:
# Cross-check: the dense bmm path and the SpatialCorrelationSampler path must agree.
corr_mm = match_layer_mm(input1, input2)
corr_scs = match_layer_scs(input1, input2)
print(corr_mm.size(), corr_scs.size())

# Same tolerance as before (10e-6 == 1e-5), now written plainly. The check fails
# loudly instead of only printing, so Restart & Run All catches regressions.
matches = torch.allclose(corr_mm, corr_scs, atol=1e-5)
print(matches)
assert matches, "bmm-based and SpatialCorrelationSampler-based correlations disagree"

torch.Size([1, 25, 196]) torch.Size([1, 25, 196])
True
