In [3]:
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe
import torch

USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda' if USE_CUDA else 'cpu')
torch.cuda.set_device(0)

class SNLI(object):
    def __init__(self, device):
        super(SNLI, self).__init__()
        # set up fields
        self.TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
        self.LABEL = data.Field(sequential=False)

        #make splits for data
        self.train, self.dev, self.test = datasets.SNLI.splits(self.TEXT, self.LABEL)

        #build the vocabulary
        self.TEXT.build_vocab(self.train,vectors=GloVe(name='6B', dim=300))
        self.LABEL.build_vocab(self.train)

        # make iterator for splits
        self.train_iter, self.dev_iter, self.test_iter = data.BucketIterator.splits((self.train, self.dev, self.test), batch_size=32, device=device)
    
    @property
    def vocab_size(self):
        return len(self.TEXT.vocab)
    
    @property
    def num_class(self):
        return len(self.LABEL.vocab)

snli = SNLI(device=DEVICE)
print(snli.vocab_size)

56220


In [1]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

PAD_TOKEN='<pad>'
PAD_INDEX=snli.TEXT.vocab.stoi[PAD_TOKEN]

class NLI(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, hidden_size, num_layers, dropout, batch_size=32, bidirectional=True, device='cuda'):
        super(Encoder, self.__init__)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding_dim=embedding_dim
        self.bidirectional = bidirectional
        self.batch_size = batch_size
        self.device = device
        self.embedding = nn.Embedding(num_embeddings, embedding_dim).from_pretrained(TEXT.vocab.vectors, freeze=True)
        print(self.embedding[1])
        self.biLSTM = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, \
                        batch_first=True, bidirectional=bidirectional, dropout=dropout)
        self.prem_hidden, self.hypo_hidden = self._init_hidden(), self._init_hidden()
    
    def forward(self, raw_prem, raw_hypo):
        prem, premlen = raw_prem
        hypo, hypolen = raw_hypo
        prem_embed, hypo_embed = self.embedding(prem), self.embedding(hypo) # (batch_size, max_seqlen, hidden_size)
        prem_encoded, hypo_encoded = self.encode(prem, premlen, hypo, hypolen)  # (max_seqlen, batch_size, hidden_size)

        assert list(prem_encoded.size())==[max(premlen), self.batch_size, self.hidden_size]
        assert list(hypo_encoded.size())==[max(hypolen), self.batch_size, self.hidden_size]

        # mask for attention
        prem_att, hypo_att = self.attention(prem_encoded, hypo_encoded) # (batch_size, hidden_size)
        prem_avg, hypo_avg = torch.sum(prem_encoded, dim=0)/premlen.unsqueeze(-1)    \
                            torch.sum(hypo_encoded, dim=0)/hypolen.unsqueeze(-1)    # (batch_size, hidden_size)


    
    def attention(prem_encoded, hypo_encoded):
        prem_encoded.permute(1,0,2) # (batch_size, max_seqlen_p, hidden_size)
        hypo_encoded.permute(1,0,2) # (batch_size, max_seqlen_h, hidden_size)

        prem_masked = self._add_mask(prem_encoded, premlen)  # (batch_size, max_seqlen_p, hidden_size)
        hypo_masked = self._add_mask(hypo_encoded, hypolen)  # (batch_size, max_seqlen_h, hidden_size)
        
        hypo_masked.permute(0,2,1)  # (batch_size, hidden_size, max_seqlen_h)
        
        align = torch.bmm(prem_masked, hypo_masked) 
        att_weights_prem = F.softmax(align, dim=2) # (batch_size, max_seqlen_p, max_seqlen_h)
        att_weights_hypo = F.softmax(align, dim=1).permute(0,2,1) # (batch_size, max_seqlen_h, max_seqlen_p)

        prem_weight_sum = torch.sum(torch.bmm(att_weights_prem, hypo_encoded), dim=1) # (batch_size, hidden_size)
        hypo_weight_sum = torch.sum(torch.bmm(att_weights_hypo, prem_encoded), dim=1) # (batch_size, hidden_size)
        return prem_weight_sum, hypo_weight_sum

    def encode(self, raw_prem, raw_hypo):
        assert len(prem_embed.size()), len(hypo_embed.size())==(3, 3)
        prem_packed = pack_padded_sequence(prem_embed, premlen, enforce_sorted=False)
        hypo_packed = pack_padded_sequence(hypo_embed, hypolen, enforce_sorted=False)
        prem_lstm_out, _ , hypo_lstm_out, _ = self.biLSTM(prem_packed, self.prem_hidden), self.biLSTM(hypo_packed, self.hypo_hidden) 
        prem_encoded, hypo_encoded = pad_packed_sequence(prem_hprem_lstm_out)[0], pad_packed_sequence(hypo_lstm_out)[0] # (max_seqlen, batch_size, embedding_size)
        return prem_encoded, hypo_encoded
    
    def _add_mask(self, in_tensor, len_vector, batch_first=False):
        mask = torch.zeros_like(in_tensor)
        for i, length in enumerate(len_vector):
            if batch_first:
                mask[i,length:] = 1
            else:
                mask[length:,i] = 1
        masked_tensor = torch.masked_fill(in_tensor, mask.byte(), -float('inf'))
        return masked_tensor

    def _init_hidden(self):
        return torch.zeros(self.num_layers*(2 if bidirectional else 1), self.batch_size, self.hidden_size).to(self.device)


NameError: name 'snli' is not defined

In [58]:
class B(nn.Module):
    def __init__(self):
        super(B, self).__init__()
        self.embedding=nn.Embedding(10,50)
    
    def forward(self,x):
        return self.embedding(x)

lengths = [3,2,4,1]
a = torch.FloatTensor(torch.rand(size=(5,4)))
aa = nn.utils.rnn.pack_padded_sequence(a,lengths,enforce_sorted=False)
b = B()
print(b(aa))        

TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not PackedSequence

PackedSequence(data=tensor([[6.],
        [2.],
        [7.],
        [3.],
        [8.],
        [9.]]), batch_sizes=tensor([2, 2, 1, 1]), sorted_indices=tensor([1, 0]), unsorted_indices=tensor([1, 0]))
PackedSequence(data=tensor([[-0.7407,  0.9941,  0.4728],
        [-0.4085,  0.9203, -0.2501],
        [-0.8257,  0.9987,  0.8237],
        [-0.7141,  0.9661,  0.2201],
        [-0.8365,  0.9995,  0.8878],
        [-0.8678,  0.9998,  0.9234]], grad_fn=<CatBackward>), batch_sizes=tensor([2, 2, 1, 1]), sorted_indices=tensor([1, 0]), unsorted_indices=tensor([1, 0]))
(tensor([[[-0.4085,  0.9203, -0.2501],
         [-0.7407,  0.9941,  0.4728]],

        [[-0.7141,  0.9661,  0.2201],
         [-0.8257,  0.9987,  0.8237]],

        [[ 0.0000,  0.0000,  0.0000],
         [-0.8365,  0.9995,  0.8878]],

        [[ 0.0000,  0.0000,  0.0000],
         [-0.8678,  0.9998,  0.9234]]], grad_fn=<IndexSelectBackward>), tensor([2, 4]))
