In [51]:
import argparse
import datasets
import math
import os
import time
import torch
import allennlp.modules.conditional_random_field as crf
###
from torch.utils.data import DataLoader
from tqdm import tqdm
from data import Conll2003, UNK, PAD
from util import pad_batch, pad_test_batch, count_parameters, calculate_epoch_time, build_mappings
from model import BiLSTM_CRF
from typing import Dict

In [52]:
def load_data():
    conll_dataset = datasets.load_dataset('conll2003')
    train_dataset = conll_dataset['train']
    valid_dataset = conll_dataset['validation']
    test_dataset = conll_dataset['test']
    return train_dataset, valid_dataset, test_dataset

In [53]:
def get_device() -> torch.device:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return device

In [54]:
def train_model(model, dataloader, optimizer, clip:int) -> float:
    model.train()
    epoch_loss = 0
    with tqdm(dataloader, unit='batch') as tqdm_loader:
        for x_padded, x_lens, y_padded in tqdm_loader:
            optimizer.zero_grad()
            result = model(x_padded, x_lens, y_padded, decode=False)
            neg_log_likelihood = result['loss']
            neg_log_likelihood.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            epoch_loss += neg_log_likelihood.item()
    return epoch_loss/len(dataloader.dataset)

In [55]:
def evaluate_model(model, dataloader) -> float:
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        with tqdm(dataloader, unit='batch') as tqdm_loader:
            for x_padded, x_lens, y_padded in tqdm_loader:
                result = model(x_padded, x_lens, y_padded, decode=False)
                neg_log_likelihood = result['loss']
                epoch_loss += neg_log_likelihood.item()
    return epoch_loss/len(dataloader.dataset)

In [56]:
def test_eval(test_data, model, batch_size, idx_to_tokens, tokens_to_idx, idx_to_tags):
    with torch.no_grad():
        predictions = []
        for batch_idx in range(len(test_data) // batch_size):
            batch = test_data.select(range(
                batch_size * batch_idx,
                batch_size * (batch_idx + 1)
            ))
            gold_labels = batch['ner_tags']
            tokens = batch['tokens']
            encoded_tokens = []
            for token_seq in tokens:
                encoded_seq = []
                for token in token_seq:
                    if token in tokens_to_idx:
                        encoded_seq.append(tokens_to_idx[token])
                    else:
                        encoded_seq.append(tokens_to_idx[UNK])
                encoded_tokens.append(torch.LongTensor(encoded_seq))
            batch_predictions = decode_batch(model, encoded_tokens, idx_to_tags=idx_to_tags)
            print('comp: ', gold_labels, batch_predictions)
            break
    # change when doing actual calcs
    return 0.0

In [57]:
def decode_batch(model, batch, idx_to_tags:Dict[int, str]):
    model.eval()
    with torch.no_grad():
        padded_batch = pad_test_batch(batch)
        x_padded, x_lens = padded_batch
        result = model(x_padded, x_lens, None, decode=True)
        actual_pred_tags = []
        for pred, _ in result['tags']:
            actual_pred_tags.append(pred)
            # actual_pred_tags.append([idx_to_tags[i] for i in pred])
    return actual_pred_tags

In [58]:
train, val, test = load_data()
ner_tags = train.features['ner_tags'].feature.names
device = get_device()

Reusing dataset conll2003 (/Users/sabhyachhabria/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)


  0%|          | 0/3 [00:00<?, ?it/s]

In [59]:
tokens_to_idx, idx_to_tokens = build_mappings(train['tokens'])

In [60]:
train_data = Conll2003(
        examples=train['tokens'][:1000], labels=train['ner_tags'][:1000],
        ner_tags=ner_tags, idx_to_tokens=idx_to_tokens, tokens_to_idx=tokens_to_idx,
        device=device
    )
val_data = Conll2003(
        examples=val['tokens'][:100], labels=val['ner_tags'][:100],
        ner_tags=ner_tags, idx_to_tokens=idx_to_tokens, tokens_to_idx=tokens_to_idx,
        device=device
    )

In [61]:
train_dataloader = DataLoader(dataset=train_data, batch_size=16, shuffle=True, collate_fn=pad_batch)
val_dataloader = DataLoader(dataset=val_data, batch_size=16, shuffle=True, collate_fn=pad_batch)

In [62]:
crf_constraints = crf.allowed_transitions(
        constraint_type='BIO',
        labels=train_data.idx_to_tags
)

In [63]:
bilstm_crf = BiLSTM_CRF(
        device=device,
        vocab_size=len(idx_to_tokens.keys()),
        num_tags=len(train_data.idx_to_tags.keys()),
        embedding_dim=50,
        lstm_hidden_dim=256,
        lstm_num_layers=1,
        dropout=0.2,
        constraints=crf_constraints,
        pad_idx=train_data.tokens_to_idx[PAD]
    )
bilstm_crf.to(device)

BiLSTM_CRF(
  (embeddings): Embedding(23624, 50)
  (lstm): LSTM(50, 128, batch_first=True, bidirectional=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (linear): Linear(in_features=256, out_features=9, bias=True)
  (crf): ConditionalRandomField()
)

In [64]:
num_params = count_parameters(bilstm_crf)
print(f'The model has {num_params:,} trainable parameters')

The model has 1,367,932 trainable parameters


In [65]:
optimizer = torch.optim.Adam(bilstm_crf.parameters())

In [66]:
train_loss = train_model(
    model=bilstm_crf, dataloader=train_dataloader, optimizer=optimizer, clip=1)

100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [00:03<00:00, 20.73batch/s]


In [67]:
val_loss = evaluate_model(model=bilstm_crf, dataloader=val_dataloader)

100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 85.43batch/s]


In [70]:
test_f1 = test_eval(test_data=test, model=bilstm_crf, batch_size=4,
                    idx_to_tokens=idx_to_tokens, tokens_to_idx=tokens_to_idx, idx_to_tags=train_data.idx_to_tags)

comp:  [[0, 0, 5, 0, 0, 0, 0, 1, 0, 0, 0, 0], [1, 2], [5, 0, 5, 6, 6, 0], [5, 0, 0, 0, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0]] [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [5, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
