In [893]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, BatchSampler
import torch.optim as optim
from transformers import BertTokenizer, BertModel
from nltk.tokenize import sent_tokenize

The get_lstm_features function returns the LSTM’s tag vectors. The function performs all the steps mentioned above for the model.
Steps:
1) It takes in characters, converts them to embeddings using our character LSTM.
2) We concat Character Embeding with embeding word vectors, use this as features that we feed to Bidirectional-LSTM.
3) The Bidirectional-LSTM generates outputs based on these set of features.
4) The output are passed through a linear layer to convert to tag space.

In [1038]:
START_TAG = 'B'
END_TAG = 'E'
NO_ENTITY_TAG = 'O'

START_SENTENCE_TOKEN = '[CLS]'
SEP_SENTENCE_TOKEN = '[SEP]'
PAD_TOKEN = '[PAD]'
UNKNOWN_TOKEN = '[UNK]'

START_SENTENCE_TOKEN_ID = 101
SEP_SENTENCE_TOKEN_ID = 102
PAD_TOKEN_ID = 0
UNKNOWN_TOKEN_ID = 100

SPECIAL_TOKENS = [START_SENTENCE_TOKEN, SEP_SENTENCE_TOKEN, PAD_TOKEN, UNKNOWN_TOKEN]
START_TEXT_TOKEN = '[PST]'
END_TEXT_TOKEN = '[PEN]'

In [1009]:
s = 'This is an example text. It demonstrates how to use embeddings layers or embedding lookup'
labels = ['O', 'O', 'B-A', 'I-A', 'E-A', 'O', 'B-A', 'O', 'O', 'O', 'O', 'B-A', 'E-A', 'O', 'B-A', 'E-A']
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', return_tensors='pt')
t = tokenizer(s, return_tensors='pt')
input_ids = t['input_ids'].squeeze().tolist()
tokens = tokenizer.convert_ids_to_tokens(input_ids)
word_ids = []

# Get the word IDs
word_ids = []
current_word_id = -1
for i, token_id in enumerate(input_ids):
    if token_id == tokenizer.cls_token_id or token_id == tokenizer.sep_token_id:
        word_ids.append(None)  # Special tokens like [CLS] and [SEP] have no corresponding word ID
        continue
    
    if tokens[i].startswith("##"):        
        word_ids.append(current_word_id)
        continue
        
    current_word_id += 1
    word_ids.append(current_word_id)
        
print(word_ids)

[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 11, 11, 11, 12, 13, 14, 14, 14, 15, 15, None]


16

In [1014]:
from transformers import BertTokenizer

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Example input sentence and corresponding labels
# input_sentence = "John lives in New York City."
# labels = ["B-PER", "O", "O", "B-LOC", "I-LOC", "I-LOC", "O"]
input_sentence = 'This is an example text. It demonstrates how to use embeddings layers or embedding lookup'
labels = ['O', 'O', 'B-A', 'I-A', 'E-A', 'O', 'B-A', 'O', 'O', 'O', 'O', 'B-A', 'E-A', 'O', 'B-A', 'E-A']

# Tokenize the input sentence
tokenized_input = tokenizer.tokenize(input_sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokenized_input)

# Initialize lists to store aligned tokens and labels
aligned_tokens = []
aligned_labels = []

# Initialize variables to keep track of token and label positions
token_pos = 0
label_pos = 0

# Loop through the tokenized input and align tokens with labels
while token_pos < len(token_ids):
    token = tokenized_input[token_pos]
    token_label = labels[label_pos]

    # Skip special tokens like [CLS], [SEP], [PAD]
    if token.startswith("##"):
        token_pos += 1
        continue

    # Add aligned token and label to the lists
    aligned_tokens.append(token)
    aligned_labels.append(token_label)

    # Move to the next token and label positions
    token_pos += 1

    # If the current token is part of a word (e.g., "New" and "York" in "New York City"),
    # increment the label position to match the next word
    if not token.endswith("##") and token_pos < len(token_ids) and not tokenized_input[token_pos].startswith("##"):
        label_pos += 1

print(len(token_ids))
print("Aligned Tokens:", aligned_tokens)
print("Aligned Labels:", aligned_labels)

22
Aligned Tokens: ['this', 'is', 'an', 'example', 'text', '.', 'it', 'demonstrates', 'how', 'to', 'use', 'em', 'layers', 'or', 'em', 'look']
Aligned Labels: ['O', 'O', 'B-A', 'I-A', 'E-A', 'O', 'B-A', 'O', 'O', 'O', 'O', 'B-A', 'B-A', 'E-A', 'O', 'O']


In [ ]:
from transformers import BertTokenizer, BertModel
import torch

# Initialize the BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Example input sentence and corresponding labels
input_sentence = "John lives in New York City."
labels = ["B-PER", "O", "O", "B-LOC", "I-LOC", "I-LOC", "O"]

# Tokenize the input sentence
tokenized_input = tokenizer.tokenize(input_sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokenized_input)

# Convert the token IDs to tensor
input_ids = torch.tensor(token_ids).unsqueeze(0)  # Add batch dimension

# Get token embeddings from the BERT model
with torch.no_grad():
    outputs = model(input_ids)
    token_embeddings = outputs.last_hidden_state

# Align token embeddings with labels
aligned_embeddings = []
aligned_labels = []

# Iterate through the token embeddings and align with labels
for idx, (token, label) in enumerate(zip(tokenized_input, labels)):
    # Skip special tokens like [CLS], [SEP], [PAD]
    if token.startswith("##"):
        continue
    
    # Add the token embedding and label to the aligned lists
    aligned_embeddings.append(token_embeddings[0, idx + 1])  # Skip [CLS] token at idx 0
    aligned_labels.append(label)

print("Number of aligned embeddings:", len(aligned_embeddings))
print("Number of aligned labels:", len(aligned_labels))

In [1059]:
class MultiSentenceDataset(Dataset):
    def __init__(self, file_dir, labels_dir=None, uncased=True, 
                 bert_model_name='bert-base-uncased', max_token_length=15):
        self.file_dir = file_dir
        self.labels_dir = labels_dir
        self.uncased = uncased
        self.max_token_length = max_token_length
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.texts_sentences, self.texts_mask, self.sentence_mask, self.labels = self.__read_data()
        self.tokens_ids, self.sentence_lengths, self.word_tokens, self.word_ids = self.__tokenize_sentences()
        self.chars_dict = self.__get_chars_dict()        
        self.chars_vocab_dim = len(self.chars_dict)
        self.chars_seq, self.chars_lengths = self.__tokenize_and_pad_characters_per_token()
        
        
    def __read_data(self):
        texts_sentences = []
        text_mask = []
        sentence_mask = []
        labels = []        
        text_enumerator = 0        
        for filename in os.listdir(self.file_dir):
            with open(os.path.join(self.file_dir, filename), 'r', encoding='utf-8') as file:                
                text = file.read()
                sentences = sent_tokenize(text)
                sentence_mask.extend(range(len(sentences)))
                text_mask.extend([text_enumerator]*len(sentences))
                texts_sentences.extend(sentences)
                
                text_enumerator += 1
            # if self.labels_dir:
            #     with open(os.path.join(self.labels_dir, filename), 'r', encoding='utf-8') as file:
            #         label = file.read()                    
            #         labels.append(NO_ENTITY_TAG + ' ' + label + ' ' + NO_ENTITY_TAG)
        return texts_sentences, text_mask, sentence_mask, labels if self.labels_dir else None
    
    def __get_words_from_tokens(self, sentence_tokens_ids):
        word_tokens = []
        word_ids = []
        tokens = self.tokenizer.convert_ids_to_tokens(sentence_tokens_ids)  
        word_id = -1        
        for token in tokens:
            if token.startswith("##"):
                word_tokens[-1] += token[2:]
            else:         
                word_id += 1
                word_tokens.append(token)
            word_ids.append(word_id)            
        return word_tokens, word_ids
        
    def __tokenize_sentences(self):
        tokens_ids = []
        word_tokens = []
        word_ids = []
        sentence_lengths = []
        for i, sentence in enumerate(self.texts_sentences):
            tokenized_sentence = self.tokenizer(sentence, return_tensors='pt') 
            sentence_tokens_ids = tokenized_sentence['input_ids'].squeeze().tolist()            
            tokens_ids.append(sentence_tokens_ids)
            
            sentence_word_tokens, sentence_word_ids = self.__get_words_from_tokens(sentence_tokens_ids)
            word_tokens.append(sentence_word_tokens)
            word_ids.append(sentence_word_ids)
            sentence_lengths.append(len(sentence_word_tokens))
            
        return tokens_ids, sentence_lengths, word_tokens, word_ids

    
    def __tokenize_and_pad_characters_per_token(self):
        encoded_char_seq = []
        chars_lengths = []
            
        for sentence_index, sentence_length in enumerate(self.sentence_lengths):    
            token_char_ids = []
            token_lengths = [] 
            for token_index in range(sentence_length):
                token = self.word_tokens[sentence_index][token_index]       
                if token in SPECIAL_TOKENS:
                    token_char_ids.append(([self.chars_dict[token]]))
                else:
                    token_char_ids.append(([self.chars_dict[c] for c in token]))         
                length = len(token_char_ids[-1])
                token_lengths.append(length)   
            encoded_char_seq.append(token_char_ids)
            chars_lengths.append(token_lengths)
        
        padded_encoded_char_seq = [[t + [0] * (self.max_token_length - len(t)) for t in s] for s in encoded_char_seq]
        return padded_encoded_char_seq, chars_lengths
    
    
    def __get_chars_dict(self):
        chars_freq = {START_SENTENCE_TOKEN: len(self.word_tokens), SEP_SENTENCE_TOKEN: len(self.word_tokens), 
             PAD_TOKEN: 0, UNKNOWN_TOKEN: 0}
        
        for sentence in self.word_tokens:
            for i, token in enumerate(sentence):
                if (token == START_SENTENCE_TOKEN) or (token == SEP_SENTENCE_TOKEN):
                    continue
                if token == PAD_TOKEN:
                    break                               
                if token == UNKNOWN_TOKEN:
                    chars_freq[token] += 1
                else: 
                    for c in token:
                        if c in chars_freq:
                            chars_freq[c] += 1
                        else:
                            chars_freq[c] = 0
        chars = dict(sorted(chars_freq.items(), key=lambda item: item[1], reverse=True))
        del chars[PAD_TOKEN]
        chars_vocab = {PAD_TOKEN: 0}
        chars_vocab.update(dict((item, i+1) for i, item in enumerate(chars)))
        return chars_vocab

    def __len__(self):
        return len(self.tokens_ids)

    def __getitem__(self, idx):   
        if self.labels is not None:
            pass
        else:
            return {
                    'token_ids': self.tokens_ids[idx], 
                    'words': self.word_tokens[idx],
                    'word_ids': self.word_ids[idx],
                    'sentence_length':  self.sentence_lengths[idx],
                    'text_mask': self.texts_mask[idx],
                    'sentence_mask': self.sentence_mask[idx],
                    'word_char_ids': self.chars_seq[idx],
                    'word_lengths': self.chars_lengths[idx],
                    }

In [1067]:
def sentence_data_collate_fn(batch):
    text_mask = []
    sentence_mask = []
    sentence_length = []
    
    token_ids = []
    token_mask = []    
    word_ids = []
    
    word_char_ids = []
    word_lengths = []
    
    max_sentence_length = 0
    max_tokens_count = 0
    
    for sentence in batch:
        if sentence['sentence_length'] > max_sentence_length:
            max_sentence_length = sentence['sentence_length']
        if len(sentence['token_ids']) > max_tokens_count:
            max_tokens_count = len(sentence['token_ids'])      
        sentence_length.append(sentence['sentence_length'])
        text_mask.append(sentence['text_mask'])
        sentence_mask.append(sentence['sentence_mask'])        
    
    for i, sentence in enumerate(batch):
        missing_sentence_length = max_sentence_length - sentence['sentence_length']
        missing_tokens_count = max_tokens_count - len(sentence['token_ids'])
        token_ids.append(sentence['token_ids'] + [0]*missing_tokens_count)
        word_ids.append(sentence['word_ids'] + [-1]*missing_tokens_count)
        token_mask.append([1]*len(sentence['token_ids']) + [0]*missing_tokens_count)

        padded_chars = sentence['word_char_ids'] + [[0]*MAX_TOKEN_LENGTH]*missing_sentence_length
        word_char_ids.append(padded_chars)
        word_lengths.append(sentence['word_lengths'] + [1]*missing_sentence_length)

    return {
                'token_ids': torch.tensor(token_ids), 
                'token_mask':  torch.tensor(token_mask),
                'word_ids': torch.tensor(word_ids),         
                'text_mask': torch.tensor(text_mask),
                'sentence_mask': torch.tensor(sentence_mask),
                'sentence_length': torch.tensor(sentence_length),
                'word_char_ids': torch.tensor(word_char_ids),
                'word_lengths': torch.tensor(word_lengths),
            }

In [1230]:
class WordBertEmbeddingModel(nn.Module):
    def __init__(self, bert_model_name='bert-base-uncased', pretrain=False, average_hidden_states=False):        
        super(WordBertEmbeddingModel, self).__init__()
        self.average_hidden_states = average_hidden_states
        self.model = BertModel.from_pretrained(bert_model_name)   
        self.embeddings_dim = self.model.config.hidden_size
        
        if not pretrain:
            for param in self.model.parameters():
                param.requires_grad = False  
    
#     word_ids = batch['word_ids']
# token_embeddings = word_embeddings
# max_words_count = torch.max(batch['sentence_length']).item()
# embeddings_dim = 768
# 
# embeddings = []
# for sentence, sentence_embed in zip(word_ids, token_embeddings):
#     prev_id = -1
#     sentence_embeddings = []
#     prev_embedding = None
#     collected_embeddings = []
#     for word_id, token_embed in zip(sentence, sentence_embed):
#         if word_id == -1:
#             break
#             
#         if prev_id == -1:
#             print(f'prev_id: {prev_id}, prev_emb: {None}')
#             prev_id = word_id
#             prev_embedding = token_embed            
#             continue
#             
#         if word_id == prev_id:
#             collected_embeddings.append(prev_embedding)
#             print(f'prev_id: {prev_id}, prev_emb: {prev_embedding[:5]}, word_id: {word_id}, coll: {len(collected_embeddings)}')
#         else:
#             if len(collected_embeddings) > 0:
#                 collected_embeddings.append(prev_embedding)
#                 mean_emb = torch.tensor(np.mean(collected_embeddings, axis=0))
#                 sentence_embeddings.append(mean_emb)
#                 print(f'prev_id: {prev_id}, prev_emb: {prev_embedding[:5]}, word_id: {word_id}, coll: {len(collected_embeddings)}, mean_emb: {mean_emb[0:5]} {len(mean_emb)}')
#                 collected_embeddings = []
#             else:
#                 print(f'prev_id: {prev_id}, prev_emb: {prev_embedding[:5]}, word_id: {word_id}, coll: {len(collected_embeddings)}')
#                 sentence_embeddings.append(prev_embedding)
#             prev_id = word_id
#         prev_embedding = token_embed
# 
# 
#     if len(collected_embeddings) > 0:
#         collected_embeddings.append(prev_embedding)
#         sentence_embeddings.append(torch.tensor(np.mean(collected_embeddings, axis=0)))
#     else:
#         sentence_embeddings.append(prev_embedding)          
#     
#     embeddings.append(torch.stack(sentence_embeddings))
# 
# 
# for i, length in enumerate(batch['sentence_length']):
#     if length.item() < max_words_count:
#         embeddings[i] = torch.cat((embeddings[i], torch.zeros(max_words_count - length.item(), embeddings_dim)), dim=0)
# 
# res = torch.stack(embeddings)

                
    def __get_mean_embeddings_for_words(self, word_ids, token_embeddings):
        max_words_count = torch.max(batch['sentence_length']).item()
        embeddings = []
        for sentence, sentence_embed in zip(word_ids, token_embeddings):
            prev_id = -1
            sentence_embeddings = []
            prev_embedding = None
            collected_embeddings = []
            for word_id, token_embed in zip(sentence, sentence_embed):
                if word_id == -1:
                    break
                    
                if prev_id == -1:
                    prev_id = word_id
                    prev_embedding = token_embed
                    continue
                    
                if word_id == prev_id:
                    collected_embeddings.append(prev_embedding)                    
                else:
                    if len(collected_embeddings) > 0:
                        collected_embeddings.append(prev_embedding)
                        sentence_embeddings.append(torch.tensor(np.mean(collected_embeddings, axis=0)))
                        collected_embeddings = []
                    else:
                        sentence_embeddings.append(prev_embedding)
                    prev_id = word_id                    
                prev_embedding = token_embed
            
            if len(collected_embeddings) > 0:
                collected_embeddings.append(prev_embedding)
                sentence_embeddings.append(torch.tensor(np.mean(collected_embeddings, axis=0)))
            else:
                sentence_embeddings.append(prev_embedding)          
               
            embeddings.append(torch.stack(sentence_embeddings))
          
        
        for i, length in enumerate(batch['sentence_length']):
            if length.item() < max_words_count:
                embeddings[i] = torch.cat((embeddings[i], torch.zeros(max_words_count - length.item(), embeddings_dim)), dim=0)  
            
        return torch.stack(embeddings)
                    
        
    def forward(self, batch):
        x = self.model(input_ids= batch['token_ids'], attention_mask=batch['token_mask'])  
        if not self.average_hidden_states:
            x = x['last_hidden_state']
        else:
            x = torch.mean(torch.stack(x.hidden_states), dim=0)
        x = self.__get_mean_embeddings_for_words(batch['word_ids'], x)
        return x        

Char representation per word

In [1239]:
class CharLSTMModel(nn.Module):
    def __init__(self, vocab_dim, embedding_dim, output_dim, bidirectional=True, num_layers=1):        
        super(CharLSTMModel, self).__init__()              
        self.embedding_dim = embedding_dim
        self.embedding_layer = nn.Embedding(vocab_dim, embedding_dim, padding_idx=0)
        self.lstm_layer = nn.LSTM(embedding_dim, output_dim // 2 if bidirectional else output_dim, 
                                  bidirectional=bidirectional, 
                                  num_layers=num_layers, batch_first=True)        
    def forward(self, batch):

        (batch_size, sentence_max_length, token_max_length) = batch['word_char_ids'].shape
        x = self.embedding_layer(batch['word_char_ids'])
        flat_token_length = batch['word_lengths'].view(-1)
        x = x.view(batch_size * sentence_max_length, token_max_length, self.embedding_dim)
        x = torch.nn.utils.rnn.pack_padded_sequence(x, flat_token_length, batch_first=True, enforce_sorted=False)
        x, _ = self.lstm_layer(x)
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True, total_length=token_max_length)
        #x = x.view(batch_size, sentence_max_length, token_max_length, x.shape[-1])
        x = x.view(batch_size, sentence_max_length, -1)
        return x

In [1052]:
class ContextLSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, bidirectional=True, num_layers=1,
                 input_dropout_rate=0.5, hidden_dropout_rate=0.5, output_dropout_rate=0.5,
                 init_hidden_to_random = True):
        super(ContextLSTMModel, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.init_hidden_to_random = init_hidden_to_random

        #LSTM layers
        self.lstm_layer = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, 
                            bidirectional=bidirectional, batch_first=True, 
                            dropout=hidden_dropout_rate if num_layers > 1 else 0)
        
        #Dropout layers for input and output
        if input_dropout_rate and input_dropout_rate > 0:
            self.dropout_input_layer = nn.Dropout(input_dropout_rate)
        
        if output_dropout_rate and output_dropout_rate > 0:
            self.dropout_output_layer = nn.Dropout(output_dropout_rate)
            
        #FC layer to map the LSTM output of into output space 
        self.fc_layer = nn.Linear(hidden_dim * (2 if bidirectional else 1), output_dim)

    def init_hidden_state(self, batch_size):
        if self.init_hidden_to_random:
            return torch.randn((2 if self.bidirectional else 1) * self.num_layers, batch_size, self.hidden_dim)
        else:
            return torch.zeros((2 if self.bidirectional else 1) * self.num_layers, batch_size, self.hidden_dim)
    
    def forward(self, sentences, sentence_lengths):
        #per batch size
        hidden = self.init_hidden_state(sentences.shape[0])
        state = self.init_hidden_state(sentences.shape[0])
        word_max_length = sentences.shape[-1]
        
        x = sentences
        if self.dropout_input_layer:
            x = self.dropout_input_layer(x)

        x = torch.nn.utils.rnn.pack_padded_sequence(x, sentence_lengths, batch_first=True, enforce_sorted=False)
        x, (self.hidden, self.state) = self.lstm_layer(x, (hidden, state))
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True, total_length=word_max_length)
        
        if self.dropout_output_layer:
            x = self.dropout_output_layer(x)            
        x = self.fc_layer(x)        
        return x

In [1240]:
bert_model_name = 'bert-base-uncased'
file_dir = 'data_exp/texts'  # Directory containing text files
#labels_dir = 'data_exp/labels'  
BATCH_SIZE = 2
CHAR_EMBEDDING_DIM = 10
CHAR_LSTM_OUTPUT = 20
CONTEXT_HIDDEN_DIM = 20
NUM_LABELS = 10
MAX_TOKEN_LENGTH = 15

train_dataset = MultiSentenceDataset(file_dir, 
                                     bert_model_name=bert_model_name, 
                                     uncased=True)

char_vocab_dim = train_dataset.chars_vocab_dim

word_model = WordBertEmbeddingModel(bert_model_name=bert_model_name)
word_embedding_dim = word_model.embeddings_dim

char_model = CharLSTMModel(vocab_dim=char_vocab_dim,
                           embedding_dim=CHAR_EMBEDDING_DIM,
                           output_dim=CHAR_LSTM_OUTPUT,                           
                           )

context_model = ContextLSTMModel(input_dim=word_model.embeddings_dim + CHAR_LSTM_OUTPUT*MAX_TOKEN_LENGTH, 
                                 hidden_dim=CONTEXT_HIDDEN_DIM,
                                 output_dim=NUM_LABELS)

In [1245]:
seq_sampler = torch.utils.data.sampler.SequentialSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, 
                              sampler=seq_sampler, collate_fn=sentence_data_collate_fn)

for index, batch in enumerate(train_dataloader):
    print('Index: ', index)
    print("Token IDs Shape:", batch['token_ids'].shape)
    print("Token Attention Mask Shape:", batch['token_mask'].shape)
    print("Word IDs Shape:", batch['word_ids'].shape)
    print("Text mask Shape:", batch['text_mask'].shape)
    print("Sentence mask Shape:", batch['sentence_mask'].shape)    
    print("Sentence length Shape:", batch['sentence_length'].shape)    
    print("Word chars Shape:", batch['word_char_ids'].shape)
    print("Word lengths Shape:", batch['word_lengths'].shape)

    word_embeddings = word_model(batch)          
    print("Embedded IDs Shape:", word_embeddings.shape)
    
    char_embeddings = char_model(batch)
    print("char_embeddings Shape:", char_embeddings.shape)
    
    word_char_embeddings = torch.cat((word_embeddings, char_embeddings), 2)
    print("word_char_embeddings Shape:", word_char_embeddings.shape)
    
    context_embeddings = context_model(word_char_embeddings, batch['sentence_length'])
    print("context_embeddings Shape:", context_embeddings.shape)
    
    if 'labels' in batch:
        labels = batch['labels']
        print("Labels Shape:", labels.shape)

Index:  0
Token IDs Shape: torch.Size([2, 9])
Token Attention Mask Shape: torch.Size([2, 9])
Word IDs Shape: torch.Size([2, 9])
Text mask Shape: torch.Size([2])
Sentence mask Shape: torch.Size([2])
Sentence length Shape: torch.Size([2])
Word chars Shape: torch.Size([2, 9, 15])
Word lengths Shape: torch.Size([2, 9])
Embedded IDs Shape: torch.Size([2, 9, 768])
char_embeddings Shape: torch.Size([2, 9, 300])
word_char_embeddings Shape: torch.Size([2, 9, 1068])
context_embeddings Shape: torch.Size([2, 1068, 10])
Index:  1
Token IDs Shape: torch.Size([2, 10])
Token Attention Mask Shape: torch.Size([2, 10])
Word IDs Shape: torch.Size([2, 10])
Text mask Shape: torch.Size([2])
Sentence mask Shape: torch.Size([2])
Sentence length Shape: torch.Size([2])
Word chars Shape: torch.Size([2, 10, 15])
Word lengths Shape: torch.Size([2, 10])
Embedded IDs Shape: torch.Size([2, 10, 768])
char_embeddings Shape: torch.Size([2, 10, 300])
word_char_embeddings Shape: torch.Size([2, 10, 1068])
context_embeddings

emissions (output of a BiLSTM or other sequence encoder) 
https://towardsdatascience.com/implementing-a-linear-chain-conditional-random-field-crf-in-pytorch-16b0b9c4b4ea


In [ ]:
class CRF(nn.Module):
    def __init__(self, nb_labels, cls_tag_id, sep_tag_id, pad_tag_id):
        super(CRF, self).__init__()    
        self.nb_labels = nb_labels
        self.CLS_TAG_ID = cls_tag_id
        self.SEP_TAG_ID = sep_tag_id
        self.PAD_TAG_ID = pad_tag_id
        self.transitions = nn.Parameter(torch.empty(self.nb_labels, self.nb_labels))
        self.init_weights()

    def init_weights(self):
        # initialize transitions from a random uniform distribution between -0.1 and 0.1
        nn.init.uniform_(self.transitions, -0.1, 0.1)
        # no transitions allowed to the beginning or from the end of sentence, from or to padding
        self.transitions.data[:, self.CLS_TAG_ID] = -10000.0
        self.transitions.data[self.SEP_TAG_ID, :] = -10000.0
        self.transitions.data[self.PAD_TAG_ID, :] = -10000.0
        self.transitions.data[:, :] = -10000.0
        # or we are already in a pad position
        self.transitions.data[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0
        self.transitions.data[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0
        
    def forward(self, emissions, labels, mask=None):
        """Compute the negative log-likelihood. See `log_likelihood` method."""
        neg_likelihood = -self.log_likelihood(emissions, labels, mask=mask)
        return neg_likelihood

    def log_likelihood(self, emissions, labels, mask=None):
        """Compute the probability of a sequence of tags given a sequence of
        emissions scores.
        Args:
            emissions (torch.Tensor): Sequence of emissions for each label.
                Shape of (batch_size, seq_len, nb_labels)
            labels (torch.LongTensor): Sequence of labels.
                Shape of (batch_size, seq_len).
            mask (torch.FloatTensor, optional): Tensor representing valid positions.
                If None, all positions are considered valid.
                Shape of (batch_size, seq_len).
        Returns:
            torch.Tensor: the (summed) log-likelihoods of each sequence in the batch.
                Shape of (1,)
        """  
        if mask is None:
            mask = torch.ones(emissions.shape[:2], dtype=torch.float)
    
        scores = self._compute_scores(emissions, labels, mask=mask)
        partition = self._compute_log_partition(emissions, mask=mask)
        return torch.sum(scores - partition)

    def _compute_scores(self, emissions, labels, mask):
        """Compute the scores for a given batch of emissions with their tags.
        Args:
            emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
            labels (Torch.LongTensor): (batch_size, seq_len)
            mask (Torch.FloatTensor): (batch_size, seq_len)
        Returns:
            torch.Tensor: Scores for each batch.
                Shape of (batch_size,)
        """
        batch_size, seq_length = labels.shape
        scores = torch.zeros(batch_size)
    
        # save first and last tags to be used later
        first_tags = labels[:, 0]
        last_valid_idx = mask.int().sum(1) - 1
        last_tags = labels.gather(1, last_valid_idx.unsqueeze(1)).squeeze()
    
        # add the transition from BOS to the first tags for each batch
        t_scores = self.transitions[self.BOS_TAG_ID, first_tags]
    
        # add the [unary] emission scores for the first tags for each batch
        # for all batches, the first word, see the correspondent emissions
        # for the first tags (which is a list of ids):
        # emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]]
        e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze()
    
        # the scores for a word is just the sum of both scores
        scores += e_scores + t_scores
    
        # now lets do this for each remaining word
        for i in range(1, seq_length):
    
            # we could: iterate over batches, check if we reached a mask symbol
            # and stop the iteration, but vecotrizing is faster due to gpu,
            # so instead we perform an element-wise multiplication
            is_valid = mask[:, i]
    
            previous_tags = labels[:, i - 1]
            current_tags = labels[:, i]
    
            # calculate emission and transition scores as we did before
            e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze()
            t_scores = self.transitions[previous_tags, current_tags]
    
            # apply the mask
            e_scores = e_scores * is_valid
            t_scores = t_scores * is_valid
    
            scores += e_scores + t_scores
    
        # add the transition from the end tag to the EOS tag for each batch
        scores += self.transitions[last_tags, self.EOS_TAG_ID]
    
        return scores

    def _compute_log_partition(self, emissions, mask):
        """Compute the partition function in log-space using the forward-algorithm.
        Args:
            emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
            mask (Torch.FloatTensor): (batch_size, seq_len)
        Returns:
            torch.Tensor: the partition scores for each batch.
                Shape of (batch_size,)
        """
        batch_size, seq_length, nb_labels = emissions.shape
    
        # in the first iteration, BOS will have all the scores
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]
    
        for i in range(1, seq_length):
            alpha_t = []
    
            for tag in range(nb_labels):
    
                # get the emission for the current tag
                e_scores = emissions[:, i, tag]
    
                # broadcast emission to all labels
                # since it will be the same for all previous tags
                # (bs, nb_labels)
                e_scores = e_scores.unsqueeze(1)
    
                # transitions from something to our tag
                t_scores = self.transitions[:, tag]
    
                # broadcast the transition scores to all batches
                # (bs, nb_labels)
                t_scores = t_scores.unsqueeze(0)
    
                # combine current scores with previous alphas
                # since alphas are in log space (see logsumexp below),
                # we add them instead of multiplying
                scores = e_scores + t_scores + alphas
    
                # add the new alphas for the current tag
                alpha_t.append(torch.logsumexp(scores, dim=1))
    
            # create a torch matrix from alpha_t
            # (bs, nb_labels)
            new_alphas = torch.stack(alpha_t).t()
    
            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * new_alphas + (1 - is_valid) * alphas
    
        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)
    
        # return a *log* of sums of exps
        return torch.logsumexp(end_scores, dim=1)

    def _viterbi_decode(self, emissions, mask):
        """Compute the viterbi algorithm to find the most probable sequence of labels
        given a sequence of emissions.

        Args:
            emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
            mask (Torch.FloatTensor): (batch_size, seq_len)

        Returns:
            torch.Tensor: the viterbi score for the for each batch.
                Shape of (batch_size,)
            list of lists of ints: the best viterbi sequence of labels for each batch
        """
        batch_size, seq_length, nb_labels = emissions.shape

        # in the first iteration, BOS will have all the scores and then, the max
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]

        backpointers = []

        for i in range(1, seq_length):
            alpha_t = []
            backpointers_t = []

            for tag in range(nb_labels):

                # get the emission for the current tag and broadcast to all labels
                e_scores = emissions[:, i, tag]
                e_scores = e_scores.unsqueeze(1)

                # transitions from something to our tag and broadcast to all batches
                t_scores = self.transitions[:, tag]
                t_scores = t_scores.unsqueeze(0)

                # combine current scores with previous alphas
                scores = e_scores + t_scores + alphas

                # so far is exactly like the forward algorithm,
                # but now, instead of calculating the logsumexp,
                # we will find the highest score and the tag associated with it
                max_score, max_score_tag = torch.max(scores, dim=-1)

                # add the max score for the current tag
                alpha_t.append(max_score)

                # add the max_score_tag for our list of backpointers
                backpointers_t.append(max_score_tag)

            # create a torch matrix from alpha_t
            # (bs, nb_labels)
            new_alphas = torch.stack(alpha_t).t()

            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * new_alphas + (1 - is_valid) * alphas

            # append the new backpointers
            backpointers.append(backpointers_t)

        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)

        # get the final most probable score and the final most probable tag
        max_final_scores, max_final_tags = torch.max(end_scores, dim=1)

        # find the best sequence of labels for each sample in the batch
        best_sequences = []
        emission_lengths = mask.int().sum(dim=1)
        for i in range(batch_size):

            # recover the original sentence length for the i-th sample in the batch
            sample_length = emission_lengths[i].item()

            # recover the max tag for the last timestep
            sample_final_tag = max_final_tags[i].item()

            # limit the backpointers until the last but one
            # since the last corresponds to the sample_final_tag
            sample_backpointers = backpointers[: sample_length - 1]

            # follow the backpointers to build the sequence of labels
            sample_path = self._find_best_path(i, sample_final_tag, sample_backpointers)

            # add this path to the list of best sequences
            best_sequences.append(sample_path)

        return max_final_scores, best_sequences

    def _find_best_path(self, sample_id, best_tag, backpointers):
        """Auxiliary function to find the best path sequence for a specific sample.

            Args:
                sample_id (int): sample index in the range [0, batch_size)
                best_tag (int): tag which maximizes the final score
                backpointers (list of lists of tensors): list of pointers with
                shape (seq_len_i-1, nb_labels, batch_size) where seq_len_i
                represents the length of the ith sample in the batch

            Returns:
                list of ints: a list of tag indexes representing the bast path
        """

        # add the final best_tag to our best path
        best_path = [best_tag]

        # traverse the backpointers in backwards
        for backpointers_t in reversed(backpointers):

            # recover the best_tag at this timestep
            best_tag = backpointers_t[best_tag][sample_id].item()

            # append to the beginning of the list so we don't need to reverse it later
            best_path.insert(0, best_tag)

        return best_path

In [2]:
class BiLISTM_CRF(nn.Module):
    def __init__(self, 
                 tag_to_ix, hidden_dim,
                 vocab_size, word_embedding_dim, pre_word_embeds=None,                  
                 char_to_ix=None, char_out_dimension=25, char_embedding_dim=25, char_num_layers=1,
                 use_gpu=False, use_char=True, use_crf=True,
                 dropout_rate=0.5):
        """
        Input parameters:
                
                vocab_size= Size of vocabulary (int)
                tag_to_ix = Dictionary that maps NER tags to indices
                word_embedding_dim = Dimension of word embeddings (int)
                hidden_dim = The hidden dimension of the LSTM layer (int)
                char_to_ix = Dictionary that maps characters to indices
                pre_word_embeds = Numpy array which provides mapping from word embeddings to word indices
                char_out_dimension = Output dimension from the encoder for character
                char_embedding_dim = Dimension of the character embeddings
                use_gpu = defines availability of GPU, 
                    when True: CUDA function calls are made
                    else: Normal CPU function calls are made
                use_crf = parameter which decides if you want to use the CRF layer for output decoding
        """
        super().__init__()
        
        self.use_gpu = use_gpu
        #self.word_embedding_dim = word_embedding_dim
        #self.hidden_dim = hidden_dim
        #self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.use_crf = use_crf
        self.use_char = use_char
        self.char_out_dimension = char_out_dimension
        
        if self.use_char:
            self.char_embeds = nn.Embedding(len(self.char_to_ix), char_embedding_dim)            
            self.char_lstm = nn.LSTM(char_embedding_dim, char_out_dimension, num_layers=char_num_layers, bidirectional=True)
    
        self.word_embeds = nn.Embedding(vocab_size, word_embedding_dim)
        
         ## ? get existent ??
        if pre_word_embeds is not None:
            #Initializes the word embeddings with pretrained word embeddings
            self.pre_word_embeds = True
            self.word_embeds.weight = nn.Parameter(torch.FloatTensor(pre_word_embeds))
        else:
            self.pre_word_embeds = False
        
        #Drop out layer 
        self.dropout = nn.Dropout(dropout_rate)
        #BiLSTM for concatenated embeddings layer 
        self.lstm = nn.LSTM(word_embedding_dim + (char_out_dimension * 2 if use_char else 0), 
                            hidden_dim, bidirectional=True)
        
        # Linear layer to maps the output of the biLSTM into tag space
        self.hidden2tag = nn.Linear(hidden_dim*2, len(tag_to_ix))
        
        # Initialize the matrix of transition parameters between entities
        if self.use_crf:            
            self.transitions = nn.Parameter(torch.zeros(len(tag_to_ix), len(tag_to_ix)))            
            # Never transfer to the start tag or from the end tag
            self.transitions.data[tag_to_ix[START_TAG], :] = -10000
            self.transitions.data[:, tag_to_ix[END_TAG]] = -10000

    
    def _get_lstm_features(self, sentence, characters, characters_lengths, d):
        
        #get characters embeddings
        chars_embeds = self.char_embeds(characters).transpose(0, 1)
        #prepaire for LSTM to avoid padding
        packed = torch.nn.utils.rnn.pack_padded_sequence(chars_embeds, characters_lengths)
        
        lstm_out, _ = self.char_lstm(packed)
        outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
        outputs = outputs.transpose(0, 1)
            
        chars_embeds_temp = torch.zeros((outputs.size(0), outputs.size(2)), dtype=torch.float32)
            
        if self.use_gpu:
            chars_embeds_temp = chars_embeds_temp.cuda()
        # cut paddings             
        for i, index in enumerate(output_lengths):
            chars_embeds_temp[i] = torch.cat((outputs[i, index-1, :self.char_lstm_dim], outputs[i, 0, self.char_lstm_dim:]))
        
        chars_embeds = chars_embeds_temp.clone()
        
        for i in range(chars_embeds.size(0)):
            chars_embeds[d[i]] = chars_embeds_temp[i]

    
        ## Loading word embeddings
        embeds = self.word_embeds(sentence)
    
        ## We concatenate the word embeddings and the character level representation
        ## to create unified representation for each word
        embeds = torch.cat((embeds, chars_embeds), 1)
    
        embeds = embeds.unsqueeze(1)
    
        ## Dropout on the unified embeddings
        embeds = self.dropout(embeds)
    
        ## Word lstm
        ## Takes words as input and generates a output at each step
        lstm_out, _ = self.lstm(embeds)
    
        ## Reshaping the outputs from the lstm layer
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim*2)
    
        ## Dropout on the lstm output
        lstm_out = self.dropout(lstm_out)
    
        ## Linear layer converts the ouput vectors to tag space
        lstm_feats = self.hidden2tag(lstm_out)
        
        return lstm_feats
    
    def viterbi_algo(self, feats):
        '''
        In this function, we implement the viterbi algorithm explained above.
        A Dynamic programming based approach to find the best tag sequence
        '''
        backpointers = []
        # analogous to forward
        
        # Initialize the viterbi variables in log space
        init_vvars = torch.Tensor(1, self.tagset_size).fill_(-10000.)
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0
        
        # forward_var at step i holds the viterbi variables for step i-1
        forward_var = init_vvars
        if self.use_gpu:
            forward_var = forward_var.cuda()
        for feat in feats:
            next_tag_var = forward_var.view(1, -1).expand(self.tagset_size, self.tagset_size) + self.transitions
            _, bptrs_t = torch.max(next_tag_var, dim=1)
            bptrs_t = bptrs_t.squeeze().data.cpu().numpy() # holds the backpointers for this step
            next_tag_var = next_tag_var.data.cpu().numpy() 
            viterbivars_t = next_tag_var[range(len(bptrs_t)), bptrs_t] # holds the viterbi variables for this step
            viterbivars_t = Variable(torch.FloatTensor(viterbivars_t))
            if self.use_gpu:
                viterbivars_t = viterbivars_t.cuda()
                
            # Now add in the emission scores, and assign forward_var to the set
            # of viterbi variables we just computed
            forward_var = viterbivars_t + feat
            backpointers.append(bptrs_t)
    
        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        terminal_var.data[self.tag_to_ix[STOP_TAG]] = -10000.
        terminal_var.data[self.tag_to_ix[START_TAG]] = -10000.
        best_tag_id = argmax(terminal_var.unsqueeze(0))
        path_score = terminal_var[best_tag_id]
        
        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
            
        # Pop off the start tag (we dont want to return that to the caller)
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG] # Sanity check
        best_path.reverse()
        return path_score, best_path
    
    def forward(self, sentence, chars, chars2_length, d):
        feats = self._get_lstm_features(sentence, chars, chars2_length, d)
        
        # Find the best path, given the features.
        if self.use_crf:
            score, tag_seq = self.viterbi_decode(feats)
        else:
            score, tag_seq = torch.max(feats, 1)
            tag_seq = list(tag_seq.cpu().data)
    
        return score, tag_seq