In [None]:

import os
os.environ["WANDB_DISABLED"] = "true"

import pickle
import numpy as np
from datasets import Dataset
from transformers import (
    BertTokenizerFast,
    BertForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)
from seqeval.metrics import classification_report, accuracy_score, f1_score
# Locations of the pickled train / test splits.
TRAIN_PATH = "/content/train_pos_data.pkl"
TEST_PATH  = "/content/test_pos_data.pkl"


def _read_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE: unpickling can execute arbitrary code, so only point this at
    files that come from a trusted source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


train_raw = _read_pickle(TRAIN_PATH)
test_raw = _read_pickle(TEST_PATH)

def normalize(samples):
    """Split every tagged sentence into parallel token and tag lists.

    Args:
        samples: iterable of sentences, each a sequence of
            ``(word, tag)`` pairs.

    Returns:
        A list with one ``{"tokens": [...], "tags": [...]}`` dict per
        sentence, in the same order as *samples*.
    """
    return [
        {
            "tokens": [word for word, _ in sentence],
            "tags": [tag for _, tag in sentence],
        }
        for sentence in samples
    ]

# Convert both splits to the {"tokens": [...], "tags": [...]} layout.
train_data = normalize(train_raw)
test_data  = normalize(test_raw)

# The label vocabulary is derived from the training split only and sorted,
# so the tag -> id mapping is the same on every run.
tag_set = sorted({tag for sample in train_data for tag in sample["tags"]})
tag2id = {tag: idx for idx, tag in enumerate(tag_set)}
id2tag = dict(enumerate(tag_set))

print("Number of tags:", len(tag_set))
print(tag_set)

# Fast tokenizer: tokenize_and_align relies on its word_ids() mapping.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

def tokenize_and_align(batch):
    """Tokenize a batch of pre-split sentences and align tags to sub-tokens.

    Only the first sub-token of each word receives the word's tag id.
    Special tokens and continuation sub-tokens receive -100.

    Args:
        batch: mapping with ``"tokens"`` (list of word lists) and
            ``"tags"`` (list of tag lists, parallel to ``"tokens"``).

    Returns:
        The tokenizer encoding with an added ``"labels"`` entry holding one
        aligned label list per sentence.
    """
    enc = tokenizer(
        batch["tokens"],
        is_split_into_words=True,
        truncation=True,
        padding=False,
    )

    def _labels_for(sample_idx):
        # Build the label row for one sentence of the batch.
        sentence_tags = batch["tags"][sample_idx]
        row = []
        previous_word = None
        for word_idx in enc.word_ids(batch_index=sample_idx):
            starts_new_word = word_idx is not None and word_idx != previous_word
            row.append(tag2id[sentence_tags[word_idx]] if starts_new_word else -100)
            previous_word = word_idx
        return row

    sample_count = len(batch["tokens"])
    enc["labels"] = [_labels_for(idx) for idx in range(sample_count)]
    return enc
# Wrap each split in a Dataset and attach input ids plus aligned labels.
train_ds = Dataset.from_list(train_data)
train_ds = train_ds.map(tokenize_and_align, batched=True)

test_ds = Dataset.from_list(test_data)
test_ds = test_ds.map(tokenize_and_align, batched=True)

# Token-classification head sized to the tag vocabulary; the id <-> tag
# mappings are passed along to the model.
model = BertForTokenClassification.from_pretrained(
    "bert-base-cased",
    label2id=tag2id,
    id2label=id2tag,
    num_labels=len(tag_set),
)
def compute_metrics(eval_pred):
    """Compute token-level accuracy and macro-F1 for POS tagging.

    The metrics are computed directly on label ids rather than with
    seqeval's ``f1_score``. seqeval evaluates chunks and treats each label
    as ``<prefix><type>``; the report this notebook printed shows labels
    such as ``BD`` and ``NP`` in place of ``VBD`` and ``NNP``. Its F1 is
    therefore not a per-token POS score.

    Args:
        eval_pred: pair ``(logits, labels)``. ``logits`` has one score per
            tag in its last axis; ``labels`` holds tag ids, with -100 at
            positions that must be ignored.

    Returns:
        ``{"accuracy": float, "f1_macro": float}``. The macro average runs
        over every tag id that occurs in the gold labels or predictions.
        Both values are 0.0 when there is no scored position.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    # Keep only positions that carry a real label.
    scored = labels != -100
    gold = labels[scored]
    guess = preds[scored]
    if gold.size == 0:
        return {"accuracy": 0.0, "f1_macro": 0.0}

    f1_per_tag = []
    for tag_id in np.union1d(gold, guess):
        gold_hit = gold == tag_id
        guess_hit = guess == tag_id
        tp = int(np.sum(gold_hit & guess_hit))
        fp = int(np.sum(~gold_hit & guess_hit))
        fn = int(np.sum(gold_hit & ~guess_hit))
        # F1 = 2TP / (2TP + FP + FN); the denominator is non-zero because
        # the tag occurs in at least one of gold / guess.
        f1_per_tag.append(2 * tp / (2 * tp + fp + fn))

    return {
        "accuracy": float(np.mean(gold == guess)),
        "f1_macro": float(np.mean(f1_per_tag)),
    }
# Hyper-parameters for fine-tuning.
training_options = {
    "output_dir": "bert-pos-model",
    "learning_rate": 5e-5,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "num_train_epochs": 2,
    "weight_decay": 0.01,
    "logging_steps": 50,
    "logging_dir": "logs",
    "report_to": "none",
}
args = TrainingArguments(**training_options)

# Collator built from the tokenizer; it batches the unpadded examples.
data_collator = DataCollatorForTokenClassification(tokenizer)

# NOTE(review): the captured output shows a warning pointing at the
# Trainer(...) call -- presumably about the `tokenizer` argument; check
# against the installed transformers version before changing it.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
# Run the trained model over the test split once.
pred = trainer.predict(test_ds)
logits = pred.predictions
pred_ids = np.argmax(logits, axis=-1)

# Read the label column once. Indexing test_ds[i] inside the loops fetched
# the same row again for every token.
gold_label_rows = test_ds["labels"]

# Per-sentence gold / predicted tag strings, skipping ignored (-100) positions.
bert_true, bert_pred = [], []
for row_idx, gold_row in enumerate(gold_label_rows):
    kept = [
        (gold_id, int(pred_ids[row_idx][pos]))
        for pos, gold_id in enumerate(gold_row)
        if gold_id != -100
    ]
    bert_true.append([id2tag[gold_id] for gold_id, _ in kept])
    bert_pred.append([id2tag[guess_id] for _, guess_id in kept])

# Token-level scoring. seqeval's f1_score / classification_report are not
# used: they evaluate chunks and read each tag as <prefix><type>, which is
# why the earlier report listed labels such as 'BD' and 'NP' instead of
# 'VBD' and 'NNP'.
flat_true = [tag for sent in bert_true for tag in sent]
flat_pred = [tag for sent in bert_pred for tag in sent]

tp_count, fp_count, fn_count, support = {}, {}, {}, {}
for gold_tag, guess_tag in zip(flat_true, flat_pred):
    support[gold_tag] = support.get(gold_tag, 0) + 1
    if gold_tag == guess_tag:
        tp_count[gold_tag] = tp_count.get(gold_tag, 0) + 1
    else:
        fn_count[gold_tag] = fn_count.get(gold_tag, 0) + 1
        fp_count[guess_tag] = fp_count.get(guess_tag, 0) + 1

report_rows = []
for tag in sorted(set(flat_true) | set(flat_pred)):
    tp = tp_count.get(tag, 0)
    predicted = tp + fp_count.get(tag, 0)
    actual = tp + fn_count.get(tag, 0)
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum else 0.0
    report_rows.append((tag, precision, recall, f1, support.get(tag, 0)))

accuracy = sum(tp_count.values()) / len(flat_true) if flat_true else 0.0
f1_macro = sum(row[3] for row in report_rows) / len(report_rows) if report_rows else 0.0

print("Accuracy:", accuracy)
print("F1-macro:", f1_macro)
print("\nClassification Report:\n")
print(f"{'':>12}{'precision':>11}{'recall':>10}{'f1-score':>10}{'support':>10}\n")
for tag, precision, recall, f1, count in report_rows:
    print(f"{tag:>12}{precision:>11.2f}{recall:>10.2f}{f1:>10.2f}{count:>10}")
def predict_pos(sentence):
    """Tag each whitespace-separated word of *sentence* with a POS tag.

    Args:
        sentence: raw text; it is split on whitespace into words.

    Returns:
        A list of ``(word, tag)`` pairs, one per word, using the
        prediction made for the word's first sub-token. An empty or
        whitespace-only sentence yields ``[]``. Words that fall beyond the
        tokenizer's truncation limit get no entry.
    """
    tokens = sentence.split()
    if not tokens:
        # Nothing to tag; skip the tokenizer and the model entirely.
        return []

    # Local import: the rest of the file never imports torch directly.
    import torch

    inputs = tokenizer(
        tokens,
        is_split_into_words=True,
        truncation=True,
        return_tensors="pt",
    )
    word_ids = inputs.word_ids()
    # The model may have been moved to an accelerator during training;
    # the inputs have to be on the same device.
    inputs = inputs.to(model.device)

    # Inference only: eval mode and no gradient tracking. The previous
    # train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(**inputs)
    finally:
        model.train(was_training)

    pred_ids = outputs.logits.argmax(-1).tolist()[0]

    results = []
    last_wid = None

    for pid, wid in zip(pred_ids, word_ids):
        # Keep the first sub-token of each word; skip special tokens.
        if wid is not None and wid != last_wid:
            results.append((tokens[wid], id2tag[pid]))
            last_wid = wid

    return results

# Quick manual check of the inference helper on one sentence.
sample_sentence = "Aditya visited Mumbai yesterday"
print("\nTEST INFERENCE:\n")
print(predict_pos(sample_sentence))


Number of tags: 46
['#', '$', "''", ',', '-LRB-', '-NONE-', '-RRB-', '.', ':', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', '``']


Map:   0%|          | 0/3131 [00:00<?, ? examples/s]

Map:   0%|          | 0/783 [00:00<?, ? examples/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Step,Training Loss
50,1.627
100,0.3565
150,0.2753
200,0.2677
250,0.2004
300,0.1985
350,0.1842
400,0.1649
450,0.1421
500,0.1279







Accuracy: 0.9614903099804697
F1-macro: 0.8584489762645905

Classification Report:



  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           '       1.00      0.98      0.99       137
           B       0.96      0.94      0.95       926
          BD       0.96      0.95      0.96       620
          BG       0.93      0.94      0.94       289
          BN       0.92      0.94      0.93       389
          BP       0.98      0.97      0.97       242
          BR       0.76      0.67      0.71        24
          BS       0.89      1.00      0.94         8
          BZ       0.98      0.97      0.98       418
           C       0.98      0.91      0.95       447
           D       0.99      0.98      0.99       770
          DT       0.99      0.90      0.94        91
           J       0.92      0.89      0.91      1088
          JR       0.87      0.93      0.90        74
          JS       0.92      0.92      0.92        26
           N       0.95      0.94      0.94      3386
          NP       0.79      0.78      0.79      1183
         NPS       0.62    