In [2]:
import torch
import torch.nn as nn
#from model.cabm import CBAM
import torch.nn.functional as F
from torch.autograd import Variable

In [3]:
class AttentionCell(nn.Module):
    """One decoding step: additive attention over time followed by a GRU update.

    Every time step of ``feats`` is scored against the previous hidden state with
    ``score(tanh(i2h(feature) + h2h(prev_hidden)))``. The scores are
    softmax-normalised over time, the features are pooled with those weights into
    a context vector, and the context is fed to a ``GRUCell``.

    Args:
        input_size: number of feature channels ``nC`` in ``feats``.
        hidden_size: size of the GRU hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(AttentionCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.GRUCell(input_size, hidden_size)
        self.hidden_size = hidden_size
        self.input_size = input_size
        # Call counter; used only to throttle the debug prints in forward().
        self.processed_batches = 0

    def forward(self, prev_hidden, feats):
        """Run one attention + GRU step.

        Args:
            prev_hidden: previous hidden state, shape ``(nB, hidden_size)``.
            feats: features laid out as ``(nC, nB, nT)`` (channels, batch, time).

        Returns:
            ``(cur_hidden, alpha)``. ``cur_hidden`` has shape ``(nB, hidden_size)``.
            ``alpha`` holds the attention weights, shape ``(nB, nT)``; each row
            sums to 1 over time.
        """
        self.processed_batches = self.processed_batches + 1

        # Put channels last so each (t, b) position is one length-nC vector.
        # Viewing the raw (nC, nB, nT) tensor as (-1, nC) would instead mix values
        # from different channels, batch items and time steps.
        feats_tbc = feats.permute(2, 1, 0)  # (nT, nB, nC)

        # nn.Linear acts on the last dimension, so no flattening is needed.
        # The (1, nB, hidden) projection of the hidden state broadcasts over all nT steps.
        feats_proj = self.i2h(feats_tbc)  # (nT, nB, hidden)
        prev_hidden_proj = self.h2h(prev_hidden).unsqueeze(0)  # (1, nB, hidden)
        emition = self.score(torch.tanh(feats_proj + prev_hidden_proj)).squeeze(-1).transpose(0, 1)  # (nB, nT)
        alpha = F.softmax(emition, dim=1)  # nB * nT

        if self.processed_batches % 10000 == 0:
            print('emition ', list(emition.detach()[0]))
            print('alpha ', list(alpha.detach()[0]))

        # Attention-weighted sum over time gives (nB, nC). There is deliberately no
        # squeeze here: it would drop the batch axis when nB == 1 and break GRUCell.
        context = (feats_tbc * alpha.transpose(0, 1).unsqueeze(-1)).sum(0)
        cur_hidden = self.rnn(context, prev_hidden)
        return cur_hidden, alpha

class Attention(nn.Module):
    """Attention decoder: unrolls an AttentionCell and maps hidden states to class scores.

    Args:
        input_size: number of feature channels ``nC`` in ``feats``.
        hidden_size: size of the decoder hidden state.
        num_classes: number of output classes produced by ``generator``.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size, hidden_size)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.generator = nn.Linear(hidden_size, num_classes)
        # Call counter; used only to throttle the debug prints in forward().
        self.processed_batches = 0

    def forward(self, feats, text_length):
        """Decode ``max(text_length)`` steps and return per-label class scores.

        Args:
            feats: features laid out as ``(nC, nB, nT)``; ``nC`` must equal ``input_size``.
            text_length: tensor of per-sample label counts with ``nB`` elements.

        Returns:
            Tensor of shape ``(sum(text_length), num_classes)``. Rows are grouped by
            sample in batch order, and each sample contributes its first
            ``text_length[b]`` decoding steps. These are raw ``nn.Linear`` outputs;
            no softmax is applied.
        """
        self.processed_batches = self.processed_batches + 1
        nC = feats.size(0)
        nB = feats.size(1)
        assert self.input_size == nC
        assert nB == text_length.numel()

        # Plain ints are safe to use as sizes, range bounds and slice ends.
        lengths = [int(n) for n in text_length.reshape(-1).tolist()]
        num_steps = max(lengths)
        debug = self.processed_batches % 500 == 0

        output_hiddens = feats.new_zeros(num_steps, nB, self.hidden_size)
        hidden = feats.new_zeros(nB, self.hidden_size)
        max_locs = torch.zeros(num_steps, nB)
        max_vals = torch.zeros(num_steps, nB)
        for i in range(num_steps):
            hidden, alpha = self.attention_cell(hidden, feats)
            output_hiddens[i] = hidden
            if debug:
                max_val, max_loc = alpha.detach().max(1)
                max_locs[i] = max_loc.cpu()
                max_vals[i] = max_val.cpu()
        if debug:
            print('max_locs', list(max_locs[0:lengths[0], 0]))
            print('max_vals', list(max_vals[0:lengths[0], 0]))

        # Keep only the valid steps of each sample and pack them back to back.
        new_hiddens = torch.cat(
            [output_hiddens[:length, b, :] for b, length in enumerate(lengths)], dim=0)
        return self.generator(new_hiddens)

In [4]:
# Build the two modules explored below: an attention cell over 251-channel
# features with a 251-unit hidden state, and a full decoder emitting 11 classes.
at = AttentionCell(input_size=251, hidden_size=251)
xt = Attention(input_size=251, hidden_size=251, num_classes=11)

In [41]:
# Trace tensor shapes step by step through AttentionCell.forward, using the
# same (channels, batch, time) layout that the class expects for `feats`.
prev_hidden = torch.rand(10, 251)
hidden_size = at.hidden_size
feats = torch.rand(251, 10, 11)
nC, nB, nT = feats.shape
print(nC, nB, nT)

# Reorder to (time, batch, channels) before flattening, so each row fed to i2h
# is one channel vector. Viewing the raw (nC, nB, nT) tensor as (-1, nC) would
# mix values from different channels, batch items and time steps.
feats_tbc = feats.permute(2, 1, 0).contiguous()
feats_proj = at.i2h(feats_tbc.view(-1, nC))
print("feats_proj.shape=", feats_proj.shape)

# Rows are ordered t-major, b-minor, matching feats_proj above.
prev_hidden_proj = at.h2h(prev_hidden).view(1, nB, hidden_size).expand(nT, nB, hidden_size).contiguous().view(-1, hidden_size)
print("prev_hidden_proj =", prev_hidden_proj.shape)

x = torch.tanh(feats_proj + prev_hidden_proj).view(-1, hidden_size)
print('x=', x.shape)
x = at.score(x)
print('x=', x.shape)
x = x.view(nT, nB)
print('x=', x.shape)

# Scores per batch item over time, shape (nB, nT).
emition = x.transpose(0, 1)
print("enmition=", emition.shape)

251 10 11
feats_proj.shape= torch.Size([110, 251])
prev_hidden_proj = torch.Size([110, 251])
x= torch.Size([110, 251])
x= torch.Size([110, 1])
x= torch.Size([11, 10])
enmition= torch.Size([10, 11])


In [6]:
# Reproduce the argument checks at the top of Attention.forward on a toy batch:
# feats is (channels, batch, time) and each of the 10 samples carries 4 labels.
feats = torch.rand(251, 10, 11)
nC, nB, nT = feats.shape
text_length = torch.tensor([4] * 10)
hidden_size, input_size = xt.hidden_size, xt.input_size
print("in=", input_size)
assert input_size == nC
assert nB == text_length.numel()
num_steps = text_length.max()
num_labels = text_length.sum()
print('(num_labels.shape=', num_labels.shape)

in= 251
(num_labels.shape= torch.Size([])
