In [1]:
!pip install tensorboard



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import torch.nn.functional as F
from conlleval import evaluate as conllevaluate

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda')

In [5]:
def argmax(vec):
    _, idx = torch.max(vec, 1)
    return idx.item()


def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

In [None]:
class BiLSTM_CRF(nn.Module): 

    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim, char_embedding_dim):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size))

        self.transitions.data[tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000
        self.hidden = self.init_hidden()
        self.char_embed = nn.Embedding(10, char_embedding_dim)

        
        self.char_cnn = nn.Conv2d(in_channels=1, out_channels=char_embedding_dim, kernel_size=(1, char_embedding_dim))

    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2),
                torch.randn(2, 1, self.hidden_dim // 2))

    def _forward_alg(self, feats):
        
        init_alphas = torch.full((1, self.tagset_size), -10000.)
        init_alphas[0][self.tag_to_ix[START_TAG]] = 0.

        forward_var = init_alphas
        device_info = forward_var.device 
        for feat in feats:
            alphas_t = []  
            for next_tag in range(self.tagset_size):
                emit_score = feat[next_tag].view(
                    1, -1).expand(1, self.tagset_size).to(device_info)
                trans_score = self.transitions[next_tag].view(1, -1).to(device_info)
                next_tag_var = forward_var + trans_score + emit_score
         
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]].to(device_info)
        alpha = log_sum_exp(terminal_var)
        return alpha

    def _get_lstm_features(self, sentence): 
        self.hidden = self.init_hidden()
        sentence = sentence.to(self.word_embeds.weight.device)
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        self.hidden = tuple(h.to(embeds.device) for h in self.hidden)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden) 
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats


    def get_char_indices(self, word_idx):
        """
        Extracts character indices using nltk.word_tokenize.
        """
        char_idx = [word_to_ix[char] for char in train_data[word_idx]['tokens']] 
        return char_idx

    def _get_lstm_features_cnn(self, sentence):
        self.hidden = self.init_hidden()

        sentence = sentence.to(self.word_embeds.weight.device)
        char_embeddings, char_ids = self.char_embed, []
        for word_idx in sentence:
           
            chars = self.get_char_indices(word_idx) 
            char_ids.append(torch.tensor(chars).to(device))
                            
        char_ids = pad_sequence(char_ids, batch_first=True, padding_value=0) 

        self.conv1 = nn.Conv1d(in_channels=char_embeddings.num_embeddings,  out_channels=11,  kernel_size=3, padding=1)  
        cnn_out = self.conv1(char_embeddings(char_ids))
        lstm_out = torch.max(F.relu(cnn_out), dim=2)[0] 
        lstm_out = lstm_out.view(len(sentence), -1)
        lstm_out, self.hidden = self.lstm(lstm_out, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim * 2)
        lstm_feats = self.hidden2tag(lstm_out)

        return lstm_feats

    def _score_sentence(self, feats, tags):
        score = torch.zeros(1, device=feats.device)
        tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags])
        for i, feat in enumerate(feats):
            score = score + \
                self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        backpointers = []

        init_vvars = torch.full((1, self.tagset_size), -10000.,device=feats.device)
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0

        forward_var = init_vvars
        for feat in feats:
            bptrs_t = [] 
            viterbivars_t = []  

            for next_tag in range(self.tagset_size):
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)

        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        feats = self._get_lstm_features(sentence)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, sentence):  
        lstm_feats = self._get_lstm_features(sentence)

        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

In [None]:
def make_data_point(sent):
    """
        Creates a dictionary from String to an Array of Strings representing the data.  
    """
    dic = {}
    sent = [s.strip().split() for s in sent]
    dic['tokens'] = ['<START>'] + [s[0] for s in sent] + ['<STOP>']
    dic['pos'] = ['<START>'] + [s[1] for s in sent] + ['<STOP>']
    dic['NP_chunk'] = ['<START>'] + [s[2] for s in sent] + ['<STOP>']
    dic['gold_tags'] = ['<START>'] + [s[3] for s in sent] + ['<STOP>']
    return dic

def read_data(filename):
    """
    Reads the CoNLL 2003 data into an array of dictionaries (a dictionary for each data point).
    """
    data = []
    with open(filename, 'r') as f:
        sent = []
        for line in f.readlines():
            if line.strip():
                sent.append(line)
            else:
                data.append(make_data_point(sent))
                sent = []
        data.append(make_data_point(sent))

    return data

In [8]:
train_data = read_data('ner.train')
print(train_data[0])

{'tokens': ['<START>', 'EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.', '<STOP>'], 'pos': ['<START>', 'NNP', 'VBZ', 'JJ', 'NN', 'TO', 'VB', 'JJ', 'NN', '.', '<STOP>'], 'NP_chunk': ['<START>', 'I-NP', 'I-VP', 'I-NP', 'I-NP', 'I-VP', 'I-VP', 'I-NP', 'I-NP', 'O', '<STOP>'], 'gold_tags': ['<START>', 'I-ORG', 'O', 'I-MISC', 'O', 'O', 'O', 'I-MISC', 'O', 'O', '<STOP>']}


In [9]:
START_TAG = "<START>"
STOP_TAG = "<STOP>"
EMBEDDING_DIM = 40
HIDDEN_DIM = 40
CHAR_EMBEDDING_DIM = 4

dev_data = read_data('ner.dev')
test_data = read_data('ner.test')
train_data = read_data('ner.train')
print(train_data[0])
print(dev_data[0])
print(test_data[0])

word_2_idx = {}
for sentence in train_data + dev_data + test_data:
    for word in sentence['tokens']:
        if word not in word_2_idx:
            word_2_idx[word] = len(word_2_idx)

tag_2_idx = {}
for sentence in train_data + dev_data + test_data:
    for word in sentence['gold_tags']:
        if word not in tag_2_idx:
            tag_2_idx[word] = len(tag_2_idx)
            

{'tokens': ['<START>', 'EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.', '<STOP>'], 'pos': ['<START>', 'NNP', 'VBZ', 'JJ', 'NN', 'TO', 'VB', 'JJ', 'NN', '.', '<STOP>'], 'NP_chunk': ['<START>', 'I-NP', 'I-VP', 'I-NP', 'I-NP', 'I-VP', 'I-VP', 'I-NP', 'I-NP', 'O', '<STOP>'], 'gold_tags': ['<START>', 'I-ORG', 'O', 'I-MISC', 'O', 'O', 'O', 'I-MISC', 'O', 'O', '<STOP>']}
{'tokens': ['<START>', 'CRICKET', '-', 'LEICESTERSHIRE', 'TAKE', 'OVER', 'AT', 'TOP', 'AFTER', 'INNINGS', 'VICTORY', '.', '<STOP>'], 'pos': ['<START>', 'NNP', ':', 'NNP', 'NNP', 'IN', 'NNP', 'NNP', 'NNP', 'NNP', 'NN', '.', '<STOP>'], 'NP_chunk': ['<START>', 'I-NP', 'O', 'I-NP', 'I-NP', 'I-PP', 'I-NP', 'I-NP', 'I-NP', 'I-NP', 'I-NP', 'O', '<STOP>'], 'gold_tags': ['<START>', 'O', 'O', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '<STOP>']}
{'tokens': ['<START>', 'SOCCER', '-', 'JAPAN', 'GET', 'LUCKY', 'WIN', ',', 'CHINA', 'IN', 'SURPRISE', 'DEFEAT', '.', '<STOP>'], 'pos': ['<START>', 'NN', ':', 'N

In [10]:
model = BiLSTM_CRF(len(word_2_idx), tag_2_idx, EMBEDDING_DIM, HIDDEN_DIM, CHAR_EMBEDDING_DIM)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [11]:
with torch.no_grad():
    precheck_sent = prepare_sequence(train_data[0]['tokens'], word_2_idx)
    precheck_tags = torch.tensor([tag_2_idx[t] for t in train_data[0]['gold_tags']], dtype=torch.long).to(device)
    print(model(precheck_sent))

(tensor(22.7877, device='cuda:0'), [1, 9, 9, 1, 9, 1, 9, 1, 9, 1, 9])


In [12]:
def generate_minibatches(training_data, batch_size):
    minibatches = []
    for i in range(0, len(training_data), batch_size):
        minibatch = training_data[i:i + batch_size]
        minibatches.append(minibatch)
    return minibatches

In [13]:
import time
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [19]:
batch_size = 128
start_time = time.time()
for epoch in range(5):  
    print(f"Epochs {epoch}")
    minibatches = generate_minibatches(train_data, batch_size)
    for i, minibatch in tqdm(enumerate(minibatches)):
        model.zero_grad()

        sentences_in = [prepare_sequence(sentence['tokens'], word_2_idx) for sentence in minibatch]
        sentences_in = pad_sequence(sentences_in, batch_first=True)
        targets = [torch.tensor([tag_2_idx[t] for t in sentence['gold_tags']], dtype=torch.long) for sentence in minibatch]
        targets = pad_sequence(targets, batch_first=True)

        loss = 0
        for sentence_in, target in zip(sentences_in, targets):
            loss += model.neg_log_likelihood(sentence_in, target)
        writer.add_scalar('Loss/train', loss.item(), epoch * len(minibatches) + i)

     
        loss.backward()
        optimizer.step()
    end_time = time.time()
    
    if epoch < 1:
        elapsed_time = end_time - start_time
        print(f"Training time for one epoch with batch size {batch_size}: {elapsed_time:.2f} seconds")
    torch.save(model.state_dict(), f'bilstm_crf_model_epoch_{epoch}.pth')


Epochs 0


118it [39:54, 20.29s/it]


Training time for one epoch with batch size 128: 2394.23 seconds
Epochs 1


118it [40:02, 20.36s/it]


Epochs 2


118it [39:45, 20.22s/it]


Epochs 3


118it [39:41, 20.18s/it]


Epochs 4


118it [39:41, 20.18s/it]


In [25]:
dev_seqs = [prepare_sequence(example['tokens'], word_2_idx) for example in dev_data]
test_seqs = [prepare_sequence(example['tokens'], word_2_idx) for example in test_data]


loaded_model = BiLSTM_CRF(len(word_2_idx), tag_2_idx, EMBEDDING_DIM, HIDDEN_DIM, CHAR_EMBEDDING_DIM)
loaded_model.load_state_dict(torch.load('bilstm_crf_model_epoch_4.pth'))
loaded_model.to(device)

tag_key_list = list(tag_2_idx.keys())
print(tag_key_list)

with open('dev_predictions.txt', 'w') as f_dev, open('test_predictions.txt', 'w') as f_test:
  with torch.no_grad():
    for dev_sent, test_sent in tqdm(zip(dev_seqs, test_seqs)):
    
      dev_sent = dev_sent.to(device)
      test_sent = test_sent.to(device)

      _ , dev_predicted_tags = loaded_model(dev_sent)
      _, test_predicted_tags = loaded_model(test_sent)

      dev_tags = [tag_key_list[tag_id] for tag_id in dev_predicted_tags]
      test_tags = [tag_key_list[tag_id] for tag_id in test_predicted_tags]

      f_dev.write(' '.join(dev_tags) + '\n')
      f_test.write(' '.join(test_tags) + '\n')

['<START>', 'I-ORG', 'O', 'I-MISC', '<STOP>', 'I-PER', 'I-LOC', 'B-LOC', 'B-MISC', 'B-ORG']


3466it [00:54, 63.79it/s] 


In [None]:
data = []
sentence = []
with open('dev_predictions.txt', 'r') as f:
    for line in f:
        line = line.strip()
        data.append(line.split(' '))

In [16]:
data[0]

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

In [17]:
gold = []
for example in dev_data:
    gold.extend(example['gold_tags'][1:-1])

In [18]:
val_data = []
for i in data:
    val_data.extend(i)

In [19]:

conllevaluate(gold, val_data)   

processed 51578 tokens with 5917 phrases; found: 1058 phrases; correct: 23.
accuracy:   0.56%; (non-O)
accuracy:  81.72%; precision:   2.17%; recall:   0.39%; FB1:   0.66
              LOC: precision:   0.00%; recall:   0.00%; FB1:   0.00  13
             MISC: precision:   2.31%; recall:   0.88%; FB1:   1.27  346
              ORG: precision:   0.00%; recall:   0.00%; FB1:   0.00  13
              PER: precision:   2.19%; recall:   0.82%; FB1:   1.19  686


(2.1739130434782608, 0.38871049518337, 0.6594982078853047)

In [20]:
data = []
sentence = []
with open('test_predictions.txt', 'r') as f:
    for line in f:
        line = line.strip()
        data.append(line.split(' '))

In [21]:
data[0]

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

In [22]:
gold = []
for example in dev_data:
    gold.extend(example['gold_tags'][1:-1])

In [23]:
val_data = []
for i in data:
    val_data.extend(i)

In [24]:

conllevaluate(gold, val_data)   

processed 51578 tokens with 5917 phrases; found: 898 phrases; correct: 15.
accuracy:   0.49%; (non-O)
accuracy:  81.97%; precision:   1.67%; recall:   0.25%; FB1:   0.44
              LOC: precision:   0.00%; recall:   0.00%; FB1:   0.00  7
             MISC: precision:   1.58%; recall:   0.55%; FB1:   0.81  316
              ORG: precision:   0.00%; recall:   0.00%; FB1:   0.00  7
              PER: precision:   1.76%; recall:   0.55%; FB1:   0.83  568


(1.670378619153675, 0.2535068446848065, 0.4402054292002935)