### Imports

In [1]:
import spacy
import checklist
from checklist.perturb import Perturb
from checklist.test_types import MFT, INV, DIR
from checklist.editor import Editor
from checklist.pred_wrapper import PredictorWrapper
import numpy as np
from pattern.en import sentiment
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments
from datasets import Dataset
from sklearn.metrics import precision_recall_fscore_support

### Load data

In [20]:
def read_iob2_file(path):
    """
    Read a tab-separated CoNLL/IOB2 file.

    Lines starting with '#' are skipped as comments. Every other non-empty
    line is split on tabs; column 1 is taken as the word and column 2 as its
    tag. A blank line ends the current sentence.

    :param path: path to read from
    :returns: list of (words, tags) tuples, one per sentence
    :raises IndexError: if a token line has fewer than three tab-separated columns
    """
    data = []
    current_words = []
    current_tags = []

    # Use a context manager so the file handle is closed deterministically,
    # even when a malformed line raises while parsing.
    with open(path, encoding='utf-8') as iob2_file:
        for line in iob2_file:
            line = line.strip()
            if line:
                if line[0] == '#':
                    continue  # skip comments
                tok = line.split('\t')

                current_words.append(tok[1])
                current_tags.append(tok[2])
            else:
                # A blank line closes a sentence; consecutive blank lines
                # must not produce empty sentences.
                if current_words:
                    data.append((current_words, current_tags))
                current_words = []
                current_tags = []

    # The file may not end with a blank line: flush the last sentence.
    if current_tags:
        data.append((current_words, current_tags))
    return data

# Parse the train and dev splits into lists of (words, tags) sentence tuples.
# The paths are relative to the notebook's working directory.
train_data= read_iob2_file('./en_ewt-ud-train.iob2')
dev_data = read_iob2_file('./en_ewt-ud-dev.iob2')

### Data preprocessing

In [21]:
# Hyperparameters
# NOTE(review): DIM_EMBEDDING, LSTM_HIDDEN and LEARNING_RATE are not referenced
# by any cell visible in this notebook — confirm whether they are still needed.
DIM_EMBEDDING = 100
LSTM_HIDDEN = 50
# Sentences per batch; used both for the manual batching below and as the
# Trainer's per-device train batch size.
BATCH_SIZE = 64
LEARNING_RATE = 0.01
# Passed to the Trainer as the number of training epochs.
EPOCHS = 5
# Symbol stored at index 0 of both the token and the label vocabulary.
PAD = '<PAD>'

In [22]:
class Vocab():
    """
    Two-way mapping between symbols and integer indices.

    The symbol handed to the constructor is stored at index 0 and serves
    both as the padding entry and as the fallback for unseen symbols.
    """

    def __init__(self, pad_unk):
        """
        :param pad_unk: symbol to reserve at index 0
        """
        self.pad_unk = pad_unk
        self.idx2word = [pad_unk]
        self.word2idx = {pad_unk: 0}

    def getIdx(self, word, add=False):
        """
        Return the index of a symbol.

        :param word: symbol to look up
        :param add: when true, an unseen symbol is appended to the vocabulary
        :returns: the symbol's index, or the index of the pad/unknown symbol
                  when it is unseen and ``add`` is false
        """
        known_idx = self.word2idx.get(word)
        if known_idx is not None:
            return known_idx
        if not add:
            return self.word2idx[self.pad_unk]
        # New symbol: it receives the next free position.
        fresh_idx = len(self.idx2word)
        self.word2idx[word] = fresh_idx
        self.idx2word.append(word)
        return fresh_idx

    def getWord(self, idx):
        """
        :param idx: integer index
        :returns: the symbol stored at that index
        """
        return self.idx2word[idx]


# Length (in tokens) of the longest training sentence; every feature and
# label row is padded to this width.
max_len = max(len(words) for words, _ in train_data)

# One vocabulary for the tokens and one for the tags; PAD sits at index 0
# of each.
token_vocab, label_vocab = Vocab(PAD), Vocab(PAD)
id_to_token = [PAD]

# Register every token and tag seen in the training data.
for tokens, tags in train_data:
    for token in tokens:
        token_vocab.getIdx(token, add=True)
    for tag in tags:
        label_vocab.getIdx(tag, add=True)

# Vocabulary sizes, PAD entry included.
NWORDS, NTAGS = len(token_vocab.idx2word), len(label_vocab.idx2word)

# convert text data with labels to indices
def data2feats(inputData, word_vocab, label_vocab, max_length=None):
    """
    Convert sentences with labels into zero-padded index tensors.

    :param inputData: sequence of sentences; item [0] of each is the word
                      list and item [1] the label list
    :param word_vocab: object whose ``getIdx(word)`` returns the word index
    :param label_vocab: object whose ``getIdx(label)`` returns the label index
    :param max_length: row width; longer sentences are truncated, shorter
                       ones are left zero-padded. Defaults to the
                       module-level ``max_len`` (the previous, implicit
                       behaviour).
    :returns: (feats, labels), two ``torch.long`` tensors of shape
              (len(inputData), max_length)
    """
    if max_length is None:
        # Backward-compatible default: width derived from the training data.
        max_length = max_len

    feats = torch.zeros((len(inputData), max_length), dtype=torch.long)
    labels = torch.zeros((len(inputData), max_length), dtype=torch.long)
    for sentPos, sent in enumerate(inputData):
        wordIdxs = [word_vocab.getIdx(word) for word in sent[0][:max_length]]
        labelIdxs = [label_vocab.getIdx(label) for label in sent[1][:max_length]]
        # One slice assignment per row instead of one tensor write per token.
        feats[sentPos, :len(wordIdxs)] = torch.tensor(wordIdxs, dtype=torch.long)
        labels[sentPos, :len(labelIdxs)] = torch.tensor(labelIdxs, dtype=torch.long)
    return feats, labels

# Index the whole training set with the vocabularies built from it; rows are
# zero-padded to max_len (index 0 is the PAD entry of both vocabularies).
train_features, train_labels = data2feats(train_data, token_vocab, label_vocab)

### Convert to batches

In [5]:
# Number of full batches; trailing sentences that do not fill a complete
# batch are dropped.
num_batches = int(len(train_features) / BATCH_SIZE)

# Reshape both tensors to (num_batches, BATCH_SIZE, max_len).
train_feats_batches, train_labels_batches = (
    tensor[:BATCH_SIZE * num_batches].view(num_batches, BATCH_SIZE, max_len)
    for tensor in (train_features, train_labels)
)

In [6]:
# Pretrained checkpoint that is fine-tuned below.
model_name = "bert-base-uncased"
# NOTE(review): this tokenizer is not used by any cell visible here — the
# model inputs are indices from token_vocab, not BERT word-piece ids. Confirm
# that this is intended.
tokenizer = BertTokenizer.from_pretrained(model_name)
# One output class per entry of label_vocab; NTAGS includes the PAD label at index 0.
model = BertForTokenClassification.from_pretrained(model_name, num_labels=NTAGS)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Trainer configuration: output goes to ./results, batch size and epoch count
# come from the hyperparameter cell, and save_steps is set to 100.
# NOTE(review): LEARNING_RATE is not passed here, so the library default
# presumably applies — confirm.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    save_steps=100,
)

In [8]:
# Metric callback handed to the Trainer.
# TODO: the evaluated class ids [1, 2, 3] are still placeholders.
def compute_metrics(pred):
    """
    Compute per-class precision, recall and F1 for token predictions.

    :param pred: object exposing ``label_ids`` (gold label indices) and
                 ``predictions`` (scores whose last axis ranges over classes)
    :returns: dict with the keys 'precision', 'recall' and 'f1', each a list
              with one value per evaluated class id
    """
    # Hugging Face's EvalPrediction carries NumPy arrays, which have no
    # torch-style ``.view(-1)``; ``np.asarray(...).reshape(-1)`` flattens
    # NumPy arrays and CPU tensors alike.
    flat_labels = np.asarray(pred.label_ids).reshape(-1)
    flat_preds = np.asarray(pred.predictions).argmax(-1).reshape(-1)

    # Label index 0 is the PAD entry of the label vocabulary: ignore padded
    # positions, as run_eval does, so they cannot count as false positives.
    real_tokens = flat_labels != 0

    # Calculate metrics for each entity class (modify class ids as needed)
    precision, recall, f1, _ = precision_recall_fscore_support(
        flat_labels[real_tokens],
        flat_preds[real_tokens],
        average=None,
        labels=[1, 2, 3],  # Replace class labels (e.g., PER, LOC, ORG)
    )

    return {
        'precision': precision.tolist(),
        'recall': recall.tolist(),
        'f1': f1.tolist(),
    }

# Fine tuning BERT

### Length of training data = 12543
Set **number_of_sentences** to the number of training sentences you want to fine-tune on.

In [9]:
# How many training sentences to fine-tune on (slicing past the end of the
# tensors simply yields all rows).
number_of_sentences = 12543
# NOTE(review): 'input_ids' are token_vocab indices padded with 0, and no
# 'attention_mask' is supplied — the training output warns about exactly this.
# If a mask is added here, run_eval must pass one as well.
# NOTE(review): padded label positions keep index 0 and so presumably
# contribute to the training loss — confirm.
features = {'input_ids': train_features[:number_of_sentences], 'label': train_labels[:number_of_sentences]}
train_dataset = Dataset.from_dict(features)

### The part below needs to be run only **once**! If the desired model is already saved in `finetuned_bert`, you can skip it.

Training takes roughly 50 s per iteration (`it`) when fine-tuning on 100 sentences.

In [10]:
# Fine-tune the pretrained model on train_dataset.
# NOTE(review): no eval_dataset is passed, so compute_metrics presumably
# never runs during training — confirm.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    compute_metrics=compute_metrics, 
)

trainer.train()
# Persist the fine-tuned weights; the evaluation cells reload them from this directory.
trainer.save_model("finetuned_bert")

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)


  0%|          | 0/35 [00:00<?, ?it/s]

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


{'train_runtime': 1228.7705, 'train_samples_per_second': 0.407, 'train_steps_per_second': 0.028, 'train_loss': 0.3649191447666713, 'epoch': 5.0}


TrainOutput(global_step=35, training_loss=0.3649191447666713, metrics={'train_runtime': 1228.7705, 'train_samples_per_second': 0.407, 'train_steps_per_second': 0.028, 'train_loss': 0.3649191447666713, 'epoch': 5.0})

## Load fine tuned BERT

In [12]:
# Load our finetuned model
# (read back from the directory written by trainer.save_model, so the
# training cell can be skipped once that directory exists)
fine_tuned = BertForTokenClassification.from_pretrained("finetuned_bert")
# print(fine_tuned)

## Evaluate our model on dev data

In [13]:
# Filled by run_eval: the decoded input tokens and the predicted tag strings,
# one list per sentence. Both keep growing across calls.
sentences = []
predictions = []

def run_eval(feats_batches, labels_batches):
    """
    Token-level accuracy of ``fine_tuned`` over batched, padded data.

    Side effect: appends the decoded tokens of every sentence to the
    module-level ``sentences`` list and the predicted tag strings to
    ``predictions``.

    :param feats_batches: iterable of (batch, seq) tensors of token indices
    :param labels_batches: iterable of (batch, seq) tensors of gold label
                           indices; index 0 marks padding and is ignored
    :returns: matches / non-pad gold labels, or 0.0 when there are none
    """
    match = 0
    total = 0
    # Inference only: without no_grad every forward pass would build and
    # keep an autograd graph.
    with torch.no_grad():
        for sents, labels in zip(feats_batches, labels_batches):
            output_scores = fine_tuned(sents)
            predicted_tags = torch.argmax(output_scores.logits, dim=-1)

            for sentence in sents.tolist():
                sentences.append([token_vocab.getWord(wordIndex) for wordIndex in sentence])
            for sentenceTags in predicted_tags.tolist():
                predictions.append([label_vocab.idx2word[tag] for tag in sentenceTags])

            # Score only positions whose gold label is not the pad index 0.
            is_real = labels != 0
            total += int(is_real.sum())
            match += int(((labels == predicted_tags) & is_real).sum())

    if total == 0:
        # No scorable tokens (e.g. no batches at all): avoid ZeroDivisionError.
        return 0.0
    return match / total

In [15]:
# Index the dev set with the vocabularies built from the training data.
dev_feats, dev_labels = data2feats(dev_data, token_vocab, label_vocab)

# Keep full batches only and reshape to (num_batches_dev, BATCH_SIZE, max_len).
num_batches_dev = int(len(dev_feats) / BATCH_SIZE)
dev_feats_batches, dev_labels_batches = (
    tensor[:BATCH_SIZE * num_batches_dev].view(num_batches_dev, BATCH_SIZE, max_len)
    for tensor in (dev_feats, dev_labels)
)

# Evaluate on at most the first 50 dev batches (the slice selects batches,
# not single sentences).
score = run_eval(dev_feats_batches[:50], dev_labels_batches[:50])

# Accuracies recorded in earlier runs:
# Training on first 10 sentences gives 0.0061 accuracy for first 4 dev batches.
# Training on first 100 sentences gives 0.7949 accuracy for first 4 dev batches.
# Training on first 100 sentences gives 0.7544 accuracy for first 50 dev batches.

print('Accuracy for dev data: {:.4f}'.format(score))

Accuracy for dev data: 0.7544
