## Installations

In [1]:
! pip install transformers
!pip install sklearn
!pip install torch
!pip install tqdm



## Imports

In [2]:
import pickle
import re

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.optim import SGD
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.models.bert.modeling_bert import BertForTokenClassification
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast

## Functions

In [3]:
def align_word_ids(texts):
    '''
    str -> list
    Returns a per-token mask marking the first sub-token of every word.

    The sentence is tokenized exactly as in predict (padded to MAX_LENGTH).
    Positions holding the first sub-token of a word get 1; every other
    position (tokens that belong to no word, and continuation sub-tokens
    of a word already seen) gets -100.

    Params:
        texts (str) the sentence

    Returns:
        label_ids (list) one entry (1 or -100) per token
    '''
    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=MAX_LENGTH)

    label_ids = []
    previous_word_idx = None
    for word_idx in tokenized_inputs.word_ids():
        # word_ids() gives None for tokens that are not part of any word.
        starts_word = word_idx is not None and word_idx != previous_word_idx
        label_ids.append(1 if starts_word else -100)
        previous_word_idx = word_idx
    return label_ids


def predict(model, sentence):
    '''
    model, str -> list
    Returns the IOB tag predicted for each word of the sentence.

    Only positions that align_word_ids marks as the first sub-token of a
    word are scored, so the result holds one tag per word, in order.
    The model is switched to eval mode and run without gradient tracking.

    Params:
        model (torch.nn.Module) the fine-tuned tagger trained on
        manually annotated comments from espncricinfo
        sentence (str) a single comment for which the tags are to be
        predicted

    Returns:
        prediction_label (list) IOB tagged list
    '''
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Module.to moves parameters in place; eval() disables dropout so
    # repeated calls give the same tags.
    model = model.to(device)
    model.eval()

    text = tokenizer(sentence, padding='max_length', max_length=MAX_LENGTH, return_tensors="pt")
    input_id = text['input_ids'][0].unsqueeze(0).to(device)
    mask = text['attention_mask'][0].unsqueeze(0).to(device)

    # Boolean (1, seq_len) mask selecting the first sub-token of every word.
    first_subtoken = torch.tensor(align_word_ids(sentence), device=device).unsqueeze(0) != -100

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        outputs = model(input_id, mask, None)

    # With labels=None and return_dict=False the model returns a tuple whose
    # first item is the logits; assumed (1, seq_len, num_labels) — TODO confirm.
    logits = outputs[0]
    logits_clean = logits[first_subtoken]

    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids2tags[i] for i in predictions]
    return prediction_label

In [4]:
def get_tokens(comm):
    '''
    str -> list
    Returns the words of a comment as the tokenizer sees them, with
    special tokens removed and the WordPiece pieces of each word re-joined.

    One entry is produced per word-initial sub-token (a token not starting
    with '##'), so the list lines up with the tags returned by predict,
    which keeps only the first sub-token of every word. Text comes from the
    tokenizer vocabulary, e.g. lower-cased for an uncased model.

    Params:
        comm (str) comment for which the tokens are required

    Returns:
        tokens (list) list of words
    '''
    special_tokens = {'[CLS]', '[SEP]', '[PAD]'}
    token_list = tokenizer.convert_ids_to_tokens(
        tokenizer.encode(comm, padding='max_length', max_length=MAX_LENGTH))

    tokens = []
    for token in token_list:
        if token in special_tokens:
            continue
        if token.startswith('##') and tokens:
            # WordPiece continuation: glue it (without '##') onto its word.
            tokens[-1] += token[2:]
        else:
            tokens.append(token)
    return tokens

In [5]:
def get_line(tag_list, comm):
    '''
    tag_list, comm -> str/None
    Returns the line of the delivery described in the comment.

    Every token tagged B-LINE or I-LINE is collected in order (even when
    the tagged tokens are not adjacent) and joined with single spaces.

    Params:
        tag_list (list) IOB tags, one per token returned by get_tokens
        comm (str) the comment

    Returns:
        line (str) the LINE-tagged tokens joined by spaces, if any exist
        None (None) if no token is tagged as LINE
    '''
    tokens = get_tokens(comm)
    assert len(tokens) == len(tag_list)
    line_words = [word for word, tag in zip(tokens, tag_list)
                  if tag in ('B-LINE', 'I-LINE')]
    if not line_words:
        return None
    return ' '.join(line_words)

In [6]:
def get_length(tag_list, comm):
    '''
    tag_list, comm -> str/None
    Returns the length of the delivery described in the comment.

    Every token tagged B-LENGTH or I-LENGTH is collected in order (even
    when the tagged tokens are not adjacent) and joined with single spaces.

    Params:
        tag_list (list) IOB tags, one per token returned by get_tokens
        comm (str) the comment

    Returns:
        length (str) the LENGTH-tagged tokens joined by spaces, if any exist
        None (None) if no token is tagged as LENGTH
    '''
    tokens = get_tokens(comm)
    assert len(tokens) == len(tag_list)
    length_words = [word for word, tag in zip(tokens, tag_list)
                    if tag in ('B-LENGTH', 'I-LENGTH')]
    if not length_words:
        return None
    return ' '.join(length_words)

In [7]:
def joiner(this, next_this):
    '''
    Joins a token with the ('#'-stripped) piece that follows it, trying to
    drop text that is duplicated at the seam.

    Heuristic: it looks for the right-most occurrence of this[-1] inside
    next_this and, if the surrounding characters also match, treats the
    prefix of next_this up to that point as an overlap with the end of
    this (e.g. 'nonsen' + 'sense' -> 'nonsense').

    Params:
        this (str) first token
        next_this (str) second token

    Returns:
        (str) the joined token
    '''
    # Nothing to compare against; the original indexed this[-1] and crashed.
    if not this:
        return this + next_this

    reversed_next = next_this[::-1]
    for i, char in enumerate(reversed_next):
        if char == this[-1]:
            if char == next_this[0]:
                # NOTE(review): rstrip removes every trailing copy of char.
                return this.rstrip(char) + next_this
            # len(this) > 1 guards this[-2] (IndexError for 1-char tokens).
            # reversed_next[i + 1] is in range: i can only be the last index
            # when char == next_this[0], which the branch above handles.
            elif len(this) > 1 and this[-2] == reversed_next[i + 1]:
                overlap = next_this[:len(next_this) - i]
                if this.endswith(overlap):
                    # Remove exactly the duplicated suffix. rstrip(overlap)
                    # would strip a *set* of characters and eat more.
                    return this[:-len(overlap)] + next_this
                # Legacy fallback kept for inputs where only the last two
                # characters matched.
                return this.rstrip(overlap) + next_this
            else:
                return this + next_this

    return this + next_this

## Model

In [8]:
class BertModel(torch.nn.Module):
    '''
    Thin wrapper around a BertForTokenClassification tagger.

    The name shadows transformers' own BertModel; it is kept so existing
    callers keep working. The wrapped model is stored as ``self.bert``.

    Params:
        pretrained_dir (str, optional) path passed to from_pretrained;
            defaults to the notebook-level ``model_dir``
        num_labels (int, optional) number of output tags; defaults to
            ``len(unique_tags)``
    '''
    def __init__(self, pretrained_dir=None, num_labels=None):
        super().__init__()
        # Fall back to the notebook globals so BertModel() behaves as before.
        if pretrained_dir is None:
            pretrained_dir = model_dir
        if num_labels is None:
            num_labels = len(unique_tags)
        self.bert = BertForTokenClassification.from_pretrained(pretrained_dir, num_labels=num_labels)

    def forward(self, input_id, mask, label):
        '''
        Runs the wrapped tagger. Returns its output as a tuple
        (return_dict=False); the first item is the loss when ``label`` is
        given, otherwise the logits.
        '''
        return self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)

## Loading intermediate variables

In [9]:
# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
with open('intermediate_files/unique_tags.pkl', mode='rb') as handle:
    unique_tags = pickle.load(file=handle)

In [10]:
# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
with open('intermediate_files/ids2tags.pkl', mode='rb') as handle:
    ids2tags = pickle.load(file=handle)

In [11]:
# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
with open('intermediate_files/tags2ids.pkl', mode='rb') as handle:
    tags2ids = pickle.load(file=handle)

In [12]:
# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
with open('intermediate_files/max_length.pkl', mode='rb') as handle:
    MAX_LENGTH = pickle.load(file=handle)

## Loading model

In [13]:
model_dir = 'models/'
model_path = model_dir + 'tagger.pt'

# Build the architecture, then overwrite its weights with the fine-tuned
# state dict, mapped onto the CPU regardless of where it was saved.
model = BertModel()
model.load_state_dict(
    torch.load(model_path, map_location='cpu')
)

Some weights of the model checkpoint at models/ were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at models/ and are newly initialized: ['classifier.w

<All keys matched successfully>

## Loading pretrained tokenizer

In [14]:
# Tokenizer saved alongside the tagger; BertTokenizerFast is imported in the
# imports cell at the top of the notebook.
tokenizer = BertTokenizerFast.from_pretrained(model_dir)

## Model demonstration

### Comment

In [15]:
comm = (
    "no-nonsense Short of a length at middle stump front leg in the leg side "
    "and an old fashioned slog across the line one bounce over the long-on boundary"
)

### Model (tagger) prediction

In [16]:
predict(model=model, sentence=comm)

['O',
 'O',
 'O',
 'B-LENGTH',
 'I-LENGTH',
 'I-LENGTH',
 'I-LENGTH',
 'O',
 'B-LINE',
 'I-LINE',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O']

### Extract line of the delivery from the comment using the tagger prediction

In [17]:
get_line(tag_list=predict(model, comm), comm=comm)

'middle stump'

### Extract length of the delivery from the comment using the tagger prediction

In [18]:
get_length(tag_list=predict(model, comm), comm=comm)

'short of a length'