In [169]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary

from random import randint

from polynomial_nets import CP_L3, CP_L3_sparse

from poly_VAE import Flatten, UnFlatten, VAE_CP_L3, VAE_CP_L3_sparse, VAE_CP_L3_sparse_LU, loss_fn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import pandas as pd

In [170]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

class Self_Attn(nn.Module):
    """Self-attention block over the spatial positions of a 2-D feature map.

    Each spatial position attends to every other position. The attended
    features are scaled by the learnable scalar ``gamma`` and added back onto
    the input (residual connection).

    Args:
        in_dim: channel count of the incoming feature map. Queries and keys
            are projected down to ``in_dim // 2`` channels; values keep
            ``in_dim`` channels.
    """

    def __init__(self, in_dim):
        super().__init__()

        # 1x1 convolutions act as per-position linear projections.
        reduced_dim = in_dim // 2
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)

        # gamma starts at zero, so a freshly built layer returns its input
        # unchanged until gamma is trained away from zero.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: input feature maps, B x C x W x H.

        Returns:
            A tuple ``(out, attention)``:
                out: ``gamma * attended_values + x``, same shape as ``x``.
                attention: B x N x N weights with N = W * H; every row sums
                    to one (softmax over the last axis).
        """
        batch, channels, dim_a, dim_b = x.size()
        n_positions = dim_a * dim_b

        # Collapse the two spatial axes into a single axis of N positions.
        queries = self.query_conv(x).view(batch, -1, n_positions).transpose(1, 2)  # B x N x C'
        keys = self.key_conv(x).view(batch, -1, n_positions)  # B x C' x N
        values = self.value_conv(x).view(batch, -1, n_positions)  # B x C x N

        # Pairwise query/key scores, normalised over the key positions.
        attention = self.softmax(torch.bmm(queries, keys))  # B x N x N

        # Weighted sum of the values for each query position, then restore
        # the original spatial layout.
        attended = torch.bmm(values, attention.transpose(1, 2))
        attended = attended.view(batch, channels, dim_a, dim_b)

        return self.gamma * attended + x, attention

In [171]:
# Toy sizes used to step through the Self_Attn computation by hand.
width = 10
height = 10
in_dim = 6  # channel count of the fake feature map (in_channels of the convs below)
m_batchsize = 5

In [172]:
# Stand-alone copies of the pieces Self_Attn creates in its constructor:
# a residual scale initialised to zero and a softmax over the last axis.
gamma = nn.Parameter(torch.zeros(1))
softmax  = nn.Softmax(dim=-1)

In [173]:
# 1x1 convolutions: query and key project to in_dim // 2 channels,
# value keeps all in_dim channels.
query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//2 , kernel_size= 1)
key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//2 , kernel_size= 1)
value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)

In [174]:
# float64 copy of the query-conv weights.
# NOTE(review): `tensor` is not referenced again in the cells shown here — confirm it is still needed.
tensor = query_conv.weight.to(torch.float64)

In [175]:
# Random feature map of shape (m_batchsize, in_dim, width, height).
x = torch.randn(m_batchsize, in_dim, width, height)
# Read the sizes back from the tensor; this also defines C (equal to in_dim here).
m_batchsize,C,width ,height = x.size()

In [176]:
# Raw outputs of the query and value convolutions.
# NOTE(review): `output` and `output2` are not referenced by the later cells shown here.
output = query_conv(x)
output2 = value_conv(x)

In [177]:
# Flatten the two spatial axes into N = width*height positions.
proj_query  = query_conv(x).view(m_batchsize, -1, width*height).permute(0,2,1)  # B * N * (in_dim//2)
proj_key =  key_conv(x).view(m_batchsize, -1, width*height) # B * (in_dim//2) * N
# Pairwise query/key scores for every pair of positions: B * N * N.
energy =  torch.bmm(proj_query, proj_key)

In [178]:
# Sanity-check the three shapes computed in the previous cell.
print(proj_query.shape, proj_key.shape, energy.shape)

torch.Size([5, 100, 3]) torch.Size([5, 3, 100]) torch.Size([5, 100, 100])


In [179]:
# Normalise the scores over the last axis so each row of weights sums to one.
attention = softmax(energy) # B * N * N
# Values keep all in_dim channels, flattened over the N positions.
proj_value = value_conv(x).view(m_batchsize, -1, width*height)
print(proj_value.shape, attention.shape) # B * C * N and B * N * N
# Weighted sum of the values for each position.
out = torch.bmm(proj_value, attention.permute(0,2,1)) # batch matrix-matrix product
print(out.shape)

# Restore the original spatial layout.
out1 = out.view(m_batchsize,C,width,height)
print(out1.shape) # B * C * W * H

# Residual connection: scale the attended features by gamma and add the input.
# gamma was initialised to zero above.
out2 = gamma*out1 + x

torch.Size([5, 6, 100]) torch.Size([5, 100, 100])
torch.Size([5, 6, 100])
torch.Size([5, 6, 10, 10])


In [180]:
# Shape of the flattened value projection (B * C * N).
proj_value.shape

torch.Size([5, 6, 100])

In [181]:
# Toy sizes for the second experiment: attention on flat vectors instead of feature maps.
batch_size = 2
input_size = 10
rank = 5  # output width of the q/k/v projections below

In [182]:
# Random batch of flat inputs, shape (batch_size, input_size).
z = torch.randn(batch_size, input_size)

In [183]:
# Linear maps from input_size down to rank for query, key and value.
# Unlike the Attention class further down, bias=False is not passed here.
U_q = nn.Linear(input_size, rank)
U_k = nn.Linear(input_size, rank)
U_v = nn.Linear(input_size, rank)

In [184]:
# Project the input; each result has shape (batch_size, rank).
q = U_q(z)
k = U_k(z)
v = U_v(z)

In [185]:
# q viewed as one column vector per sample: (batch_size, rank, 1).
q.unsqueeze(2).shape

torch.Size([2, 5, 1])

In [186]:
# k viewed as one row vector per sample: (batch_size, 1, rank).
k.unsqueeze(1).shape

torch.Size([2, 1, 5])

In [187]:
# Per-sample outer product of q and k, shape (batch_size, rank, rank).
# This rebinds `energy` from the convolutional experiment above.
energy = torch.bmm(q.unsqueeze(2), k.unsqueeze(1))

In [188]:
# Divide the scores by sqrt(rank) before the softmax.
energy_scaled = energy / rank ** 0.5

In [189]:
# Confirm the score tensor is (batch_size, rank, rank).
energy.shape

torch.Size([2, 5, 5])

In [190]:
# The scaling denominator used above, i.e. sqrt(rank).
rank ** 0.5

2.23606797749979

In [191]:
# Softmax over the last axis; rebinds `attention` from the convolutional experiment above.
attention = energy_scaled.softmax(dim = -1)

In [195]:
# Attention weights keep the (batch_size, rank, rank) shape of the scores.
attention.shape

torch.Size([2, 5, 5])

In [198]:
# Check the normalisation: summing over the softmax axis should give ones.
torch.sum(attention, axis=2)

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)

In [193]:
# v viewed as one row vector per sample: (batch_size, 1, rank).
v.unsqueeze(1).shape

torch.Size([2, 1, 5])

In [194]:
# Apply the attention weights to v: (B, 1, rank) x (B, rank, rank) -> (B, 1, rank).
# This rebinds `out` from the convolutional experiment above.
out = torch.bmm(v.unsqueeze(1), attention.permute(0,2,1))

In [76]:
# NOTE(review): this cell carries an older execution count than the cells above.
# With the current 3-D `attention` (2, 5, 5) and 2-D `v` (2, 5), the inner
# dimensions of this matmul (5 vs 2) do not line up. The (3, 5) result shown
# below presumably came from an earlier kernel state — confirm or remove.
out = attention @ v

In [51]:
# Inspect the attended values produced by the previous cell.
out

tensor([[-0.5283, -0.0202,  0.4027, -0.2598,  0.0642],
        [-0.5338,  0.0027,  0.3663, -0.1216,  0.1013],
        [-0.5773,  0.0289,  0.4369, -0.3464,  0.0117]], grad_fn=<MmBackward0>)

In [28]:
# NOTE(review): this residual add raises (see the recorded RuntimeError).
# The last axis of `out` has size rank (5) while `z` has size input_size (10),
# so `out` must be projected back to input_size before it can be added to `z`.
out2 = gamma*out + z

RuntimeError: The size of tensor a (5) must match the size of tensor b (10) at non-singleton dimension 1

In [None]:
class Attention(nn.Module):
    """Attention layer over rank-``k`` projections of a flat input vector.

    The input is projected to query, key and value vectors of length ``k``.
    The per-sample outer product of query and key, scaled by ``sqrt(k)`` and
    soft-maxed over its last axis, weights the value vector; the result is
    mapped to ``o`` features by a final linear layer.

    Args:
        d: input dimension; inputs are reshaped to ``(-1, d)``.
        k: rank, i.e. the length of the query/key/value projections.
        o: output dimension of the final linear layer.
    """

    def __init__(self, d, k, o):
        super().__init__()

        self.input_dimension = d
        self.rank = k
        self.output_dimension = o

        # Bias-free projections of the input to query, key and value.
        self.layer_U_q = nn.Linear(d, k, bias=False)
        self.layer_U_k = nn.Linear(d, k, bias=False)
        self.layer_U_v = nn.Linear(d, k, bias=False)

        # Output map from the attended rank-k vector to o features.
        self.layer_C = nn.Linear(k, o)

    def forward(self, z):
        """Run the layer on ``z``.

        Args:
            z: tensor whose elements can be reshaped to ``(-1, d)``.

        Returns:
            Tensor of shape ``(B, 1, o)`` where ``B`` is the number of rows
            after the reshape; the singleton axis comes from the batched
            matrix product and is kept as is.
        """
        flat = z.reshape(-1, self.input_dimension)

        query = self.layer_U_q(flat)  # (B, k)
        key = self.layer_U_k(flat)  # (B, k)
        value = self.layer_U_v(flat)  # (B, k)

        # Outer product of query and key per sample, scaled by sqrt(rank).
        scores = torch.bmm(query.unsqueeze(-1), key.unsqueeze(1)) / self.rank ** 0.5
        weights = torch.softmax(scores, dim=-1)  # (B, k, k)

        # Row vector of values times the transposed weights: (B, 1, k).
        mixed = torch.bmm(value.unsqueeze(1), weights.transpose(1, 2))

        return self.layer_C(mixed)