In [1]:
import json
import os
import re
import bisect
from pathlib import Path

import torch
import numpy as np
import pandas as pd
from datasets import Dataset
from spacy.lang.en import English
from transformers import DebertaV2ForTokenClassification, DebertaV2TokenizerFast
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers import Trainer
from transformers.training_args import TrainingArguments
from transformers.data.data_collator import DataCollatorForTokenClassification

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Inference settings shared by the cells below.
INFERENCE_MAX_LENGTH = 3500  # max_length handed to the tokenizer; longer inputs are truncated
CONF_THRESH = 0.90  # threshold for "O" class: below it, the best non-"O" label is used instead
URL_THRESH = 0.1  # threshold for URL (not referenced in the cells visible here -- TODO confirm usage)
AMP = False  # forwarded to TrainingArguments(fp16=...)

In [3]:
# Single source of truth for the training-file location.
TRAIN_PATH = Path("../data/train.json")

# `data` holds the raw records (a list of dicts indexed by field name below).
# JSON is UTF-8 by specification, so the encoding is passed explicitly instead
# of relying on the platform's locale default.
with open(TRAIN_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

# Same file as a DataFrame (not referenced by the cells visible here).
data_df = pd.read_json(TRAIN_PATH)

In [4]:
# spaCy English language object; not referenced by the cells visible here
# (presumably kept for tokenizing new text elsewhere -- TODO confirm).
nlp = English()

def find_span(target: list[str], document: list[str]) -> list[list[int]]:
    """
    Locate every non-overlapping occurrence of `target` inside `document`.

    Args:
        target: Sequence of tokens to look for.
        document: Sequence of tokens to search in.

    Returns:
        One list of consecutive document indices per occurrence, in order of
        appearance. Occurrences are matched left to right and do not overlap.
        An empty `target` yields no spans.
    """
    needle = list(target)
    width = len(needle)
    spans = []

    # An empty target has no meaningful position in the document.
    if width == 0:
        return spans

    start = 0
    last_start = len(document) - width
    while start <= last_start:
        # Compare a full window at every start position so that a failed
        # partial match never hides a match beginning inside it
        # (e.g. ["a", "b"] within ["a", "a", "b"]).
        if list(document[start:start + width]) == needle:
            spans.append(list(range(start, start + width)))
            start += width  # resume after the match: spans never overlap
        else:
            start += 1

    return spans

def spacy_to_hf(data: dict, idx: int) -> slice:
    """
    Translate a spaCy token index into a slice over the HF tokenizer output.

    `data["token_map"]` gives, per character, the index of the spaCy token it
    belongs to; `data["offset_mapping"]` gives, per HF token, its
    (start, end) character offsets. The returned slice starts at the position
    `bisect_left` finds for the token's first character in the list of end
    offsets, and is extended while the token's last character lies beyond the
    current HF token's end offset.
    """
    offsets = data["offset_mapping"]
    char_positions = np.where(np.array(data["token_map"]) == idx)[0]

    token_ends = [off[1] for off in offsets]
    first_char = char_positions.min()
    first = bisect.bisect_left(token_ends, first_char)

    last_char = char_positions.max()
    last = first
    total = len(offsets)
    # Walk forward until an HF token ends at or after the last character.
    while last < total and last_char > offsets[last][1]:
        last += 1

    return slice(first, last + 1)

In [5]:
class CustomTokenizer:
    """
    Callable that rebuilds a document's text from its tokens, tokenizes it with
    the wrapped tokenizer and records which source token each character
    belongs to.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, max_length: int) -> None:
        """Store the wrapped tokenizer and the truncation length."""
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, example: dict) -> dict:
        """
        Tokenize one example.

        Reads `example["tokens"]` and `example["trailing_whitespace"]`.
        Returns the tokenizer output plus `token_map`: one entry per character
        of the rebuilt text, holding the index of the token the character
        belongs to, or -1 for a space added after a token.
        """
        pieces = []
        char_to_token = []

        words_and_spaces = zip(example["tokens"], example["trailing_whitespace"])
        for position, (word, has_space) in enumerate(words_and_spaces):
            pieces.append(word)
            char_to_token += [position] * len(word)
            if has_space:
                pieces.append(" ")
                char_to_token.append(-1)

        full_text = "".join(pieces)
        encoded = self.tokenizer(
            full_text,
            return_offsets_mapping=True,
            truncation=True,
            max_length=self.max_length,
        )

        return {**encoded, "token_map": char_to_token}

In [7]:
# Directory both the tokenizer and the model are loaded from via from_pretrained().
MODEL_PATH = Path("../pre_trained_model/archive")

In [8]:
# Copy the fields needed for inference from the raw records into a Dataset,
# one column per field.
ds = Dataset.from_dict({
    column: [record[column] for record in data]
    for column in ("full_text", "document", "tokens", "trailing_whitespace")
})

tokenizer = DebertaV2TokenizerFast.from_pretrained(MODEL_PATH)

# Tokenize every example, using as many worker processes as os.cpu_count() reports.
ds = ds.map(
    CustomTokenizer(tokenizer=tokenizer, max_length=INFERENCE_MAX_LENGTH),
    num_proc=os.cpu_count(),
)

Map (num_proc=8): 100%|██████████| 6807/6807 [06:03<00:00, 18.74 examples/s] 


In [11]:
# Build a Trainer that is only used for prediction in the cells below.
model = DebertaV2ForTokenClassification.from_pretrained(MODEL_PATH)
collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Outputs go to the current directory, no experiment reporting, one example per eval batch.
args = TrainingArguments(
    output_dir=".",
    fp16=AMP,
    report_to="none",
    per_device_eval_batch_size=1,
)

trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=collator,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [14]:
# Restrict the run to the first five examples of `ds` (indices 0-4).
selected_rows = ds.select(range(5))

In [18]:
# Run the model on the selected rows and keep only the `.predictions` field.
# NOTE(review): the next cell applies softmax over dim=2, so this is presumably a
# (documents, tokens, labels) array of logits -- confirm.
predictions = trainer.predict(selected_rows).predictions

  7%|▋         | 508/6807 [3:07:32<38:45:21, 22.15s/it]


In [20]:
# Label lookup tables come from the model configuration.
id2label = model.config.id2label
o_index = model.config.label2id["O"]

# Class probabilities, plain argmax labels, and the probability of "O" per token.
pred_softmax = torch.from_numpy(predictions).softmax(dim=-1).numpy()
preds = predictions.argmax(-1)
o_preds = pred_softmax[..., o_index]

# Best label when "O" is excluded: zero its probability on a copy, then argmax.
preds_without_o = pred_softmax.copy()
preds_without_o[..., o_index] = 0
preds_without_o = preds_without_o.argmax(-1)

# Where "O" is not confident enough, fall back to the best non-"O" label.
preds_final = np.where(o_preds < CONF_THRESH, preds_without_o, preds)

In [21]:
# Show the thresholded label ids (one row per selected document).
preds_final

array([[12, 12, 12, ...,  0,  0,  0],
       [12,  2,  8, ...,  0,  0,  0],
       [12, 12, 12, ...,  0,  0,  0],
       [12, 12, 12, ...,  0,  0,  0],
       [12, 12, 12, ..., 12, 12, 12]])

In [24]:
# Map every HF-token prediction back to the source token its character span
# starts in, keeping the first prediction seen for each (document, token) pair.
# No label filtering is applied here: "O" predictions are recorded as well.
processed = []
pairs = set()

# Iterate over documents
for p, token_map, offsets, tokens, doc in zip(
    preds_final,
    selected_rows["token_map"],
    selected_rows["offset_mapping"],
    selected_rows["tokens"],
    selected_rows["document"],
):
    # Iterate over the HF tokens of this document
    for token_pred, (start_idx, end_idx) in zip(p, offsets):
        label_pred = id2label[token_pred]

        if start_idx + end_idx == 0:
            # (0, 0) offset, e.g. the [CLS] token i.e. BOS
            continue

        # The span may start on the space that precedes the word.
        if token_map[start_idx] == -1:
            start_idx += 1

        # ignore whitespace-only tokens such as "\n\n"
        while start_idx < len(token_map) and tokens[token_map[start_idx]].isspace():
            start_idx += 1

        if start_idx >= len(token_map):
            break

        token_id = token_map[start_idx]

        # -1 marks a trailing-space character, not a token. Recording it would
        # store tokens[-1] (the document's last token) as its text, so skip it.
        if token_id == -1:
            continue

        pair = (doc, token_id)
        if pair in pairs:
            continue

        processed.append(
            {"document": doc, "token": token_id, "label": label_pred, "token_str": tokens[token_id]}
        )
        pairs.add(pair)

In [25]:
# Preview the first records only; the full list holds up to one entry per
# token of every selected document and is too long to display usefully.
processed[:20]

[{'document': 7, 'token': 0, 'label': 'O', 'token_str': 'Design'},
 {'document': 7, 'token': 1, 'label': 'O', 'token_str': 'Thinking'},
 {'document': 7, 'token': 2, 'label': 'O', 'token_str': 'for'},
 {'document': 7, 'token': 3, 'label': 'O', 'token_str': 'innovation'},
 {'document': 7, 'token': 4, 'label': 'O', 'token_str': 'reflexion'},
 {'document': 7, 'token': 5, 'label': 'O', 'token_str': '-'},
 {'document': 7, 'token': 6, 'label': 'O', 'token_str': 'Avril'},
 {'document': 7, 'token': 7, 'label': 'O', 'token_str': '2021'},
 {'document': 7, 'token': 8, 'label': 'O', 'token_str': '-'},
 {'document': 7,
  'token': 9,
  'label': 'B-NAME_STUDENT',
  'token_str': 'Nathalie'},
 {'document': 7, 'token': 10, 'label': 'I-NAME_STUDENT', 'token_str': 'Sylla'},
 {'document': 7, 'token': 12, 'label': 'O', 'token_str': 'Challenge'},
 {'document': 7, 'token': 13, 'label': 'O', 'token_str': '&'},
 {'document': 7, 'token': 14, 'label': 'O', 'token_str': 'selection'},
 {'document': 7, 'token': 16, '