In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [0]:
class AAConv(nn.Module):
    """Attention-augmented 2-D convolution.

    The output is the channel-wise concatenation of
      * a standard convolution of the input (``op_dim - dv`` channels), and
      * multi-head self-attention over all ``height * width`` positions,
        followed by a 1x1 convolution (``dv`` channels).

    Attention logits are the scaled query/key dot products plus learned
    relative-position logits along the height and width axes.

    Args:
        ip_dim: Number of input channels.
        op_dim: Total number of output channels; the conv branch produces
            ``op_dim - dv`` of them.
        dk: Total key depth across all heads.
        dv: Total value depth across all heads (attention output channels).
        dq: Total query depth across all heads; must equal ``dk`` because
            queries and keys are dot-multiplied per head.
        num_heads: Number of attention heads; must divide ``dk``, ``dv``, ``dq``.
        ker_size: Kernel size of the convolutional branch ("same" padding, so
            the spatial size is preserved and both branches can be concatenated).
        height, width: Spatial size the module is built for; the relative
            embeddings are sized for it and ``forward`` rejects other sizes.

    Raises:
        ValueError: If ``num_heads`` does not divide the depths or ``dq != dk``.
    """

    def __init__(self, ip_dim, op_dim, dk, dv, dq, num_heads, ker_size, height, width):
        super().__init__()
        if dk % num_heads or dv % num_heads or dq % num_heads:
            raise ValueError("num_heads must divide dk, dv and dq")
        if dq != dk:
            raise ValueError("dq must equal dk (queries and keys are dot-multiplied)")
        self.ip_dim = ip_dim
        self.op_dim = op_dim
        self.dk = dk
        self.dv = dv
        self.dq = dq
        self.num_heads = num_heads
        self.ker_size = ker_size
        self.height = height
        self.width = width
        self.dk_per_head = dk // num_heads
        self.dv_per_head = dv // num_heads
        self.dq_per_head = dq // num_heads
        # Relative embeddings: one row per offset in [-(n-1), n-1], initialised
        # from a normal distribution with std dk_per_head ** -0.5.
        scale = self.dk_per_head ** -0.5
        self.rel_embeddings_w = nn.Parameter(
            torch.randn(2 * width - 1, self.dk_per_head) * scale)
        self.rel_embeddings_h = nn.Parameter(
            torch.randn(2 * height - 1, self.dk_per_head) * scale)
        # 1x1 conv producing q, k and v in a single pass (split order q, k, v).
        self.conv_qkv = nn.Conv2d(ip_dim, dk + dv + dq, 1)
        self.softmax = nn.Softmax(dim=-1)
        self.attention_conv = nn.Conv2d(dv, dv, 1)
        # "same" padding keeps H x W so the output can be concatenated with
        # the attention branch.
        self.conv = nn.Conv2d(ip_dim, op_dim - dv, ker_size, padding="same")

    def forward(self, x):
        """Apply the augmented convolution.

        Args:
            x: Tensor of shape ``(batch, ip_dim, height, width)``.

        Returns:
            Tensor of shape ``(batch, op_dim, height, width)``: the conv-branch
            channels followed by the attention-branch channels.

        Raises:
            ValueError: If the spatial size of ``x`` differs from the size the
                module was built for.
        """
        batch_size, _, H, W = x.size()
        if (H, W) != (self.height, self.width):
            raise ValueError(
                f"expected spatial size ({self.height}, {self.width}), got ({H}, {W})")
        qkv = self.conv_qkv(x)
        q, k, v = torch.split(qkv, [self.dq, self.dk, self.dv], dim=1)
        # Split channels into heads and flatten positions row-major (h * W + w).
        q = q.reshape(batch_size, self.num_heads, self.dq_per_head, H * W)
        k = k.reshape(batch_size, self.num_heads, self.dk_per_head, H * W)
        v = v.reshape(batch_size, self.num_heads, self.dv_per_head, H * W)
        q = q / (self.dk_per_head ** 0.5)
        # (batch, heads, HW_query, HW_key)
        logits = torch.einsum('bndq,bndk->bnqk', q, k)
        s_h, s_w = self.relative_pos_embeddings(q)
        weights = self.softmax(logits + s_h + s_w)
        # (batch, heads, dv_per_head, HW_query)
        attn = torch.einsum('bnqk,bndk->bndq', weights, v)
        attn = attn.reshape(batch_size, self.dv, H, W)
        attn_out = self.attention_conv(attn)
        conv_out = self.conv(x)
        return torch.cat([conv_out, attn_out], dim=1)

    def relative_pos_embeddings(self, q):
        """Compute relative-position logits for every query/key pair.

        Args:
            q: Queries of shape ``(batch, num_heads, dk_per_head, height * width)``.

        Returns:
            ``(s_h, s_w)``, each ``(batch, num_heads, HW, HW)``, where ``s_w``
            depends only on the column offset and ``s_h`` only on the row offset
            between key and query.
        """
        bsz, num_heads, dkh, _ = q.size()
        q = q.reshape(bsz, num_heads, dkh, self.height, self.width)
        s_w = self.rel_1d(q, self.rel_embeddings_w, self.height, self.width,
                          num_heads, [0, 1, 2, 4, 3, 5])
        # Swap the spatial axes so the same 1-D routine runs along height; the
        # transpose mask then restores (row, col) ordering of the logits.
        s_h = self.rel_1d(q.permute(0, 1, 2, 4, 3), self.rel_embeddings_h,
                          self.width, self.height, num_heads, [0, 1, 4, 2, 5, 3])
        return s_h, s_w

    def rel_1d(self, q, rel_k, h, w, num_heads, trans_mask):
        """Relative logits along the last spatial axis of ``q``.

        Args:
            q: Queries of shape ``(batch, num_heads, d, h, w)``.
            rel_k: Relative embeddings of shape ``(2 * w - 1, d)``.
            h, w: Sizes of the first and second spatial axes of ``q``.
            num_heads: Number of heads.
            trans_mask: Permutation applied to the tiled logits, laid out as
                ``(batch, heads, x_query, x_key, y_query, y_key)``, before
                flattening to query/key position indices.

        Returns:
            Tensor of shape ``(batch, num_heads, h * w, h * w)``.
        """
        rel_logits = torch.einsum('bndxy,md->bnxym', q, rel_k)
        rel_logits = rel_logits.reshape(-1, num_heads * h, w, 2 * w - 1)
        rel_logits = self.rel_to_abs(rel_logits)
        rel_logits = rel_logits.reshape(-1, num_heads, h, w, w)
        # The logits do not depend on the key's position along the other
        # spatial axis, so broadcast them over it.
        rel_logits = rel_logits.unsqueeze(3).expand(-1, -1, -1, h, -1, -1)
        rel_logits = rel_logits.permute(*trans_mask)
        return rel_logits.reshape(-1, num_heads, h * w, h * w)

    def rel_to_abs(self, x):
        """Convert relative-indexed logits to absolute indexing.

        Args:
            x: Tensor ``(batch, heads, L, 2L - 1)``; ``x[..., i, m]`` is the
                logit of query ``i`` at relative offset ``m - (L - 1)``.

        Returns:
            Tensor ``(batch, heads, L, L)`` with
            ``out[..., i, j] == x[..., i, j - i + L - 1]``.
        """
        bsz, num_heads, l, _ = x.size()
        # Pad one column, flatten, then pad L-1 more: reshaping to
        # (L+1, 2L-1) shifts each row so the absolute keys line up.
        x = F.pad(x, (0, 1), 'constant', 0)
        flat_x = x.reshape(bsz, num_heads, l * 2 * l)
        flat_x_padded = F.pad(flat_x, (0, l - 1), 'constant', 0)
        x = flat_x_padded.reshape(bsz, num_heads, l + 1, 2 * l - 1)
        return x[:, :, :l, l - 1:]