In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import torchtext
from torchtext import data
from nltk import word_tokenize, sent_tokenize
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm_notebook as tqdm

from typing import Dict, Iterable, Tuple, List

In [2]:
def make_examples(df: pd.DataFrame, fields: Dict[str, data.Field]):
    fields = {field_name: (field_name, field)
                       for field_name, field in fields.items()}
    for _, row in tqdm(df.iterrows()):
        example = data.Example.fromdict(row, fields)
        yield example

In [3]:
class SequenceClassifierAttention(nn.Module):
    # this follows the word-level attention from Yang et al. 2016
    # https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf
    
    def __init__(self, n_hidden, *, batch_first=False):
        super().__init__()
        self.mlp = nn.Linear(n_hidden, n_hidden)
        # word context vector
        self.u_w = nn.Parameter(torch.rand(n_hidden))
        self.batch_first = batch_first
    
    def forward(self, X):
        if not self.batch_first:
            # make the input (batch_size, timesteps, features)
            X = X.transpose(1, 0)
                                
        # get the hidden representation of the sequence
        u_it = F.tanh(self.mlp(X))
        # get attention weights for each timestep
        alpha = F.softmax(torch.matmul(u_it, self.u_w), dim=1)
        
        # get the weighted representation of the sequence
        # and then get the sum
        # (add a size 1 dimension to alpha so each time step's features could be scaled)
        weighted_sequence = X * alpha.unsqueeze(2)
        out = torch.sum(weighted_sequence, dim=1)
        return out, alpha

In [4]:
def get_sequence_lengths(sequences, padding_value):
    def get_seq_len(sequence):
        seq_len = 0
        for n, item in enumerate(sequence, 1):
            if item == padding_value:
                break
            seq_len = n
        return seq_len
    lengths = np.array([get_seq_len(seq) for seq in sequences])
    return lengths

In [5]:
def get_nonzero_sequences(sequences, lengths):
    # keep track of the indexes
    # of non-empty sequences
    indexes = torch.nonzero(lengths).view(-1)
    nonzero_seqs = sequences[indexes]
    return nonzero_seqs, indexes

In [6]:
class HierarchicalAttentionNetwork(nn.Module):
    def __init__(self, *, n_hidden: int, n_classes: int,
                 vocab_size, embedding_dim, embedding_weights=None,
                 padding_idx=None):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim, 
                                  padding_idx=padding_idx)
        if embedding_weights is not None:
            self.embed.data.weight.copy_(embedding_weights)
            
        self.word_encoder = nn.GRU(embedding_dim, n_hidden, bidirectional=True,
                                   batch_first=True)
        self.word_attention = SequenceClassifierAttention(n_hidden * 2, 
                                                          batch_first=True)
        self.sentence_encoder = nn.GRU(n_hidden * 2, n_hidden, bidirectional=True,
                                       batch_first=True)
        self.sentence_attention = SequenceClassifierAttention(n_hidden * 2,
                                                              batch_first=True)
        self.out = nn.Linear(n_hidden * 2, n_classes)
        
    
    @staticmethod
    def _repack_with_zero_seqs(*, seq_vectors: torch.Tensor, batch_size: int,
                               indexes_nonzero: List[int]):
        # unfortunately,
        # this is the only way I have found to "repackage"
        # the sequence vectors into the original batch size
        # and preserving the gradient graph
        encoded_dim = seq_vectors.shape[1]
        repacked_seqs = [torch.zeros(encoded_dim, device=seq_vectors.device,
                                     requires_grad=True).unsqueeze(0)
                         for _ in range(batch_size)]
        for seq_vector, index in zip(seq_vectors, indexes_nonzero):
            repacked_seqs[index] = seq_vector.unsqueeze(0)
        
        repacked_seqs = torch.cat(repacked_seqs)
        return repacked_seqs
        
    def encode_sentences_words(self, X):
        batch_size, n_sents, n_words = X.shape
        encoded_sents_word = []
        sentence_alphas = []
        sentence_lengths = torch.zeros(batch_size, dtype=torch.int32)
        
        for i in range(n_sents):
            sentence_words = X[:,i,:]
            lengths = get_sequence_lengths(sentence_words, self.embed.padding_idx)
            lengths = torch.tensor(lengths, dtype=torch.int64)
            
            sentence_words_nonzero, indexes_nonzero = get_nonzero_sequences(sentence_words,
                                                                            lengths)
            lengths_nonzero = lengths[indexes_nonzero]
            if indexes_nonzero.dim() > 0:
                sorted_lengths, sorted_indices = torch.sort(lengths_nonzero, 
                                                            descending=True)
                sentence_words = sentence_words_nonzero[sorted_indices]
            else:
                sorted_lengths, sorted_indices = lengths_nonzero, indexes_nonzero
                sentence_words = sentence_words_nonzero.unsqueeze(0)

            words_embedded = self.embed(sentence_words)
            words_embedded = pack_padded_sequence(words_embedded, sorted_lengths, 
                                                      batch_first=True)
            words_encoded, _ = self.word_encoder(words_embedded)
            words_encoded, sorted_lengths = pad_packed_sequence(words_encoded,
                                                                batch_first=True)
            
            _, unsorted_indices = torch.sort(sorted_indices)
            words_encoded = words_encoded[unsorted_indices]
            
            sentence_vector, sentence_alpha = self.word_attention(words_encoded)
            
            # "re-insert" zero vectors as placeholders
            # for the sequences (sentences) that have already ended
            sentence_vector = self._repack_with_zero_seqs(seq_vectors=sentence_vector,
                                                          batch_size=batch_size,
                                                          indexes_nonzero=indexes_nonzero)
            sentence_alpha = self._repack_with_zero_seqs(seq_vectors=sentence_alpha,
                                                         batch_size=batch_size,
                                                         indexes_nonzero=indexes_nonzero)
            
            # unsqueeze the sentence vector to insert dummy "sentence timestep" dimension
            # so that we can concatenate on it
            
            encoded_sents_word.append(sentence_vector.unsqueeze(1))
            sentence_alphas.append(sentence_alpha)
            
            # keep track of the sentence lengths
            # (increment them if they were non-zero)
            # so we can pad/pack the sentences
            # when encoding the whole document later
            for index in indexes_nonzero:
                sentence_lengths[index] += 1
            
        encoded_sents_word = torch.cat(encoded_sents_word, dim=1)
        return encoded_sents_word, sentence_alphas, sentence_lengths
        
        
    def forward(self, X):
        encoded_sents_word, sentence_alphas, sentence_lengths = self.encode_sentences_words(X)
        encoded_sents, _ = self.sentence_encoder(encoded_sents_word)
        encoded_docs, document_alpha = self.sentence_attention(encoded_sents)
        out = self.out(encoded_docs)
        return out, sentence_alphas, document_alpha
        


In [7]:
data_path = Path(Path.home(), 'projects', 'nlp', 'yelp_data')
file_path = Path(data_path, 'review_chunk.gz')
assert file_path.exists()

review_chunk = pd.read_json(file_path, orient='records', lines=True,
                            chunksize=10000)

In [8]:
review_chunk = next(review_chunk)
review_chunk.head()

Unnamed: 0,business_id,cool,date,funny,review_id,stars,text,useful,user_id
0,0W4lkclzZThpx3V65bVgig,0,2016-05-28,0,v0i_UHJMo_hPBq9bxWvW4w,5,"Love the staff, love the meat, love the place....",0,bv2nCi5Qv5vroFiqKGopiw
1,AEx2SYEUJmTxVVB18LlCwA,0,2016-05-28,0,vkVSCC7xljjrAI4UGfnKEQ,5,Super simple place but amazing nonetheless. It...,0,bv2nCi5Qv5vroFiqKGopiw
2,VR6GpWIda3SfvPC-lg9H3w,0,2016-05-28,0,n6QzIUObkYshz4dz2QRJTw,5,Small unassuming place that changes their menu...,0,bv2nCi5Qv5vroFiqKGopiw
3,CKC0-MOWMqoeWf6s-szl8g,0,2016-05-28,0,MV3CcKScW05u5LVfF6ok0g,5,Lester's is located in a beautiful neighborhoo...,0,bv2nCi5Qv5vroFiqKGopiw
4,ACFtxLv8pGrrxMm6EgjreA,0,2016-05-28,0,IXvOzsEMYtiJI0CARmj77Q,4,Love coming here. Yes the place always needs t...,0,bv2nCi5Qv5vroFiqKGopiw


In [9]:

words = data.Field(tokenize=word_tokenize, batch_first=True)
sentences = data.NestedField(words, tokenize=sent_tokenize)

label = data.Field(sequential=False, use_vocab=False)

fields = {'text': sentences, 'stars': label}

In [10]:
examples = list(make_examples(review_chunk, fields=fields))
dataset = data.Dataset(examples, fields)
device = torch.device('cuda:0')

input_name = 'text'
batch_size = 16
iterator = data.BucketIterator(dataset, batch_size, device=device, repeat=False,
                               sort_key=lambda x: len(getattr(x, input_name)))




In [11]:
sentences.build_vocab(dataset, max_size=50000)

In [12]:

n_classes = 5
n_hidden = 50
embedding_dim = 200
vocab_size = len(sentences.vocab)
padding_idx = sentences.vocab.stoi[sentences.pad_token]

net = HierarchicalAttentionNetwork(n_hidden=n_hidden, 
                                   padding_idx=padding_idx,
                                   embedding_dim=embedding_dim, 
                                   vocab_size=vocab_size, 
                                   n_classes=n_classes)
net.to(device)
net

HierarchicalAttentionNetwork(
  (embed): Embedding(46323, 200, padding_idx=1)
  (word_encoder): GRU(200, 50, batch_first=True, bidirectional=True)
  (word_attention): SequenceClassifierAttention(
    (mlp): Linear(in_features=100, out_features=100, bias=True)
  )
  (sentence_encoder): GRU(100, 50, batch_first=True, bidirectional=True)
  (sentence_attention): SequenceClassifierAttention(
    (mlp): Linear(in_features=100, out_features=100, bias=True)
  )
  (out): Linear(in_features=100, out_features=5, bias=True)
)

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

In [14]:
n_epochs = 5

report_every = 100

for epoch in range(n_epochs):
    print(f'Epoch {epoch+1}')
    running_loss = 0
    for n, batch in enumerate(tqdm(iterator)):
        optimizer.zero_grad()
        
        inputs = batch.text
        targets = batch.stars - 1
        outputs, *_ = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += float(loss)
        
        if n % report_every == report_every - 1:
            mean_loss = running_loss / report_every
            print('Mean loss:', mean_loss)
            running_loss = 0

Epoch 1


Mean loss: 1.472281777858734
Mean loss: 1.4550573337078094
Mean loss: 1.4672764348983764
Mean loss: 1.4358659839630128
Mean loss: 1.4158501279354097
Mean loss: 1.3958953130245209

Epoch 2


Mean loss: 1.3782466781139373
Mean loss: 1.3051426988840102
Mean loss: 1.2911820089817048
Mean loss: 1.2696714252233505
Mean loss: 1.2692230015993118
Mean loss: 1.216897214651108

Epoch 3


Mean loss: 1.1767490357160568


KeyboardInterrupt: 

In [None]:
some_batch = next(iter(iterator))

In [None]:
first_sent = some_batch.text[:,0,]
get_sequence_lengths(first_sent, padding_idx)

In [None]:
first_sent = some_batch.text[:,-1,]
first_sent

In [None]:
outputs, *_ = net(some_batch.text)
outputs = F.softmax(outputs, dim=-1)
np.argmax(outputs.data.numpy(), axis=1) + 1

In [None]:
some_batch.stars.data.numpy()

In [22]:
t = torch.rand((5, 10, 3))
t

tensor([[[ 0.6832,  0.7560,  0.5971],
         [ 0.1242,  0.4542,  0.5848],
         [ 0.7981,  0.5346,  0.7373],
         [ 0.8094,  0.8537,  0.7586],
         [ 0.0591,  0.1274,  0.3652],
         [ 0.7516,  0.9491,  0.4236],
         [ 0.7122,  0.3119,  0.2438],
         [ 0.4165,  0.0557,  0.9150],
         [ 0.8173,  0.9313,  0.0369],
         [ 0.0041,  0.5269,  0.0507]],

        [[ 0.3365,  0.2361,  0.5023],
         [ 0.3859,  0.2031,  0.2038],
         [ 0.6666,  0.2204,  0.6155],
         [ 0.3849,  0.6044,  0.3613],
         [ 0.2057,  0.8306,  0.4697],
         [ 0.8115,  0.3566,  0.9043],
         [ 0.9975,  0.8262,  0.3989],
         [ 0.4226,  0.6329,  0.5481],
         [ 0.0649,  0.4659,  0.4612],
         [ 0.6775,  0.3589,  0.7245]],

        [[ 0.1666,  0.0239,  0.4086],
         [ 0.9947,  0.2983,  0.8580],
         [ 0.1936,  0.8046,  0.5299],
         [ 0.5637,  0.9643,  0.0417],
         [ 0.3955,  0.9059,  0.4972],
         [ 0.6955,  0.6288,  0.1333],
        