In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
from torch.autograd import Variable

In [2]:
class HAN_Model(nn.Module):
    """Hierarchical Attention Network: word-level attention, then sentence-level attention.

    Input documents are a LongTensor of token ids shaped
    (batch, sentence_num, sentence_length). Token id 0 is treated as padding and is
    excluded from both attention poolings. A sentence made only of padding is
    excluded from sentence attention.

    Args:
        vocab_size: number of embedding rows (unused when ``is_pretrain`` is True).
        embedding_size: embedding dimension. When ``is_pretrain`` is True, the width
            of ``weights`` is used instead, so the word GRU input always matches.
        gru_size: hidden size of each GRU direction. Encoder outputs are
            ``2 * gru_size`` wide because both GRUs are bidirectional.
        class_num: number of output logits.
        is_pretrain: initialise the embedding from ``weights`` (fine-tuned, not frozen).
        weights: 2-D (vocab, dim) tensor or array-like of pretrained embeddings.

    Raises:
        ValueError: if ``is_pretrain`` is True but ``weights`` is None.
    """

    def __init__(self, vocab_size, embedding_size, gru_size, class_num, is_pretrain=False, weights=None):
        super(HAN_Model, self).__init__()
        if is_pretrain:
            if weights is None:
                raise ValueError("is_pretrain=True requires `weights`")
            # from_pretrained needs a 2-D float tensor; also accept numpy arrays / lists.
            weights = torch.as_tensor(weights, dtype=torch.float)
            self.embedding = nn.Embedding.from_pretrained(weights, freeze=False)
            # Keep the GRU input size consistent with the pretrained table.
            embedding_size = self.embedding.embedding_dim
        else:
            self.embedding = nn.Embedding(vocab_size, embedding_size)

        self.word_gru = nn.GRU(input_size=embedding_size, hidden_size=gru_size, num_layers=1,
                               bidirectional=True, batch_first=True)
        # torch.Tensor(n, 1) would hand back *uninitialised* memory (possibly NaN/inf),
        # so the context vectors are allocated empty and initialised explicitly below.
        self.word_context = nn.Parameter(torch.empty(2 * gru_size, 1), requires_grad=True)
        self.word_dense = nn.Linear(2 * gru_size, 2 * gru_size)

        self.sentence_gru = nn.GRU(input_size=2 * gru_size, hidden_size=gru_size, num_layers=1,
                                   bidirectional=True, batch_first=True)
        self.sentence_context = nn.Parameter(torch.empty(2 * gru_size, 1), requires_grad=True)
        self.sentence_dense = nn.Linear(2 * gru_size, 2 * gru_size)
        self.fc = nn.Linear(2 * gru_size, class_num)

        nn.init.uniform_(self.word_context, -0.1, 0.1)
        nn.init.uniform_(self.sentence_context, -0.1, 0.1)

    @staticmethod
    def _attend(outputs, dense, context, mask):
        """Masked additive-attention pooling over dim 1.

        Args:
            outputs: (N, T, H) encoder outputs.
            dense: projection applied before tanh, H -> H.
            context: (H, 1) learned context vector.
            mask: bool (N, T, 1); False marks positions to ignore.

        Returns:
            (N, H) weighted sum of ``outputs``. A row whose mask is all False pools to zeros.
        """
        scores = torch.matmul(torch.tanh(dense(outputs)), context)  # (N, T, 1)
        weights = F.softmax(scores, dim=1)
        # Drop masked positions, then renormalise. The epsilon keeps fully-masked rows
        # at exactly zero instead of dividing by zero.
        weights = weights.masked_fill(~mask, 0.0)
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-4)
        return torch.sum(outputs * weights, dim=1)

    def forward(self, x, gpu=False):
        """Classify a batch of documents.

        Args:
            x: LongTensor (batch, sentence_num, sentence_length) of token ids; 0 = padding.
            gpu: kept only for backward compatibility and ignored. Masks are now built
                on the input's own device and dtype.

        Returns:
            (batch, class_num) unnormalised logits.
        """
        sentence_num = x.shape[1]
        sentence_length = x.shape[2]
        # reshape (not view) so non-contiguous inputs such as slices/transposes also work.
        tokens = x.reshape(-1, sentence_length)  # (bs*sentence_num, sentence_length)
        token_mask = (tokens != 0).unsqueeze(2)  # (bs*sentence_num, sentence_length, 1)

        word_outputs, _ = self.word_gru(self.embedding(tokens))  # (bs*sentence_num, sentence_length, 2gru_size)
        sentence_vector = self._attend(word_outputs, self.word_dense, self.word_context, token_mask)
        sentence_vector = sentence_vector.view(-1, sentence_num, word_outputs.shape[-1])  # bs*sentence_num*2gru_size

        # A sentence takes part in sentence attention iff it has at least one real token.
        sentence_mask = token_mask.view(-1, sentence_num, sentence_length).any(dim=2, keepdim=True)  # bs*sentence_num*1
        sentence_outputs, _ = self.sentence_gru(sentence_vector)  # bs*sentence_num*2gru_size
        document_vector = self._attend(sentence_outputs, self.sentence_dense, self.sentence_context,
                                       sentence_mask)  # bs*2gru_size
        return self.fc(document_vector)  # bs*class_num


In [3]:
# Smoke test: 64 documents x 50 sentences x 100 tokens. Every token is padding (id 0)
# except the first ten tokens of the first sentence of the first document.
torch.manual_seed(0)  # make the randomly initialised model (and hence its logits) reproducible
han_model = HAN_Model(vocab_size=30000, embedding_size=200, gru_size=50, class_num=4)
x = torch.zeros(64, 50, 100, dtype=torch.long)  # build ids directly as int64, no numpy float round trip
x[0, 0, :10] = 1
output = han_model(x)
assert output.shape == (64, 4), output.shape
print(output.shape)

torch.Size([64, 4])
