In [3]:
import json
%load_ext autoreload
%autoreload 2

In [4]:
input_path = 'data/train.json'
with open(input_path, 'r') as infile:   
    data = json.load(infile)

In [5]:
from itertools import product
tags = ['B', 'I']
pii_types = ['EMAIL', 'ID_NUM', 'NAME_STUDENT', 'PHONE_NUM', 'STREET_ADDRESS', 'URL_PERSONAL', 'USERNAME']

labels = ['O'] + [f'{tag}-{pii_type}' for pii_type, tag in product(pii_types, tags)]
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}

In [6]:
from transformers import AutoTokenizer

training_model_path = "microsoft/deberta-v3-large"
tokenizer = AutoTokenizer.from_pretrained(training_model_path)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
import torch
sample = data[0]

def tokenize_and_align_labels(sample, tokeninzer, label2id):
    tokenized_input = tokenizer.encode_plus(
        sample['tokens'],
        is_split_into_words=True
    )

    target = []
    for word_id in tokenized_input.word_ids():
        if word_id is None:
            target.append(-100)
        else:
            target.append(label2id[sample['labels'][word_id]])
    # target = torch.tensor(target, dtype=torch.long)

    return tokenized_input, target

output = tokenize_and_align_labels(sample, tokenizer, label2id)

In [10]:
import numpy as np
def tokenize(example, tokenizer, label2id):
    text = []

    # these are at the character level
    labels = []

    for t, l, ws in zip(example["tokens"], example["labels"], example["trailing_whitespace"]):

        text.append(t)
        labels.extend([l]*len(t))

        # if there is trailing whitespace
        if ws:
            text.append(" ")
            labels.append("O")


    tokenized = tokenizer("".join(text), return_offsets_mapping=True, truncation=False)

    labels = np.array(labels)

    text = "".join(text)
    token_labels = []

    for start_idx, end_idx in tokenized.offset_mapping:

        # CLS token
        if start_idx + end_idx == 0: 
            token_labels.append(label2id["O"])
            continue

        # case when token starts with whitespace
        if text[start_idx].isspace():
            start_idx += 1
        
        while start_idx >= len(labels):
            start_idx -= 1

        token_labels.append(label2id[labels[start_idx]])

    length = len(tokenized.input_ids)

    return {
        **tokenized,
        "labels": token_labels,
        "length": length
    }

tokenize(sample, tokenizer, label2id)
print('hello')

{'input_ids': [1,
  2169,
  12103,
  270,
  3513,
  28310,
  4593,
  271,
  57498,
  24360,
  16789,
  271,
  1609,
  30065,
  12287,
  662,
  86260,
  6738,
  429,
  1857,
  279,
  1637,
  273,
  380,
  264,
  408,
  305,
  6998,
  1879,
  308,
  384,
  390,
  262,
  6870,
  265,
  266,
  663,
  269,
  262,
  791,
  2269,
  260,
  458,
  1444,
  269,
  266,
  791,
  2269,
  302,
  1663,
  264,
  262,
  3742,
  265,
  72791,
  1398,
  897,
  260,
  263,
  72791,
  1398,
  736,
  260,
  287,
  15724,
  261,
  10040,
  268,
  5152,
  271,
  92671,
  2531,
  280,
  51388,
  260,
  3045,
  294,
  9110,
  25247,
  42255,
  268,
  1931,
  280,
  65426,
  7933,
  260,
  285,
  261,
  262,
  791,
  2269,
  287,
  698,
  59729,
  6000,
  285,
  269,
  266,
  4981,
  5190,
  3395,
  272,
  3832,
  262,
  1008,
  7392,
  265,
  262,
  791,
  263,
  1279,
  262,
  1959,
  280,
  268,
  1068,
  264,
  282,
  1315,
  260,
  45110,
  30097,
  435,
  329,
  1637,
  303,
  386,
  5228,
  294,
  1795,
 