In [49]:
from dataset_loader import PunctuationDataset
import pandas as pd
import torch
import training_params
from tqdm import tqdm
from seqeval.metrics import f1_score, accuracy_score
from sklearn import metrics
from transformers import AlbertForTokenClassification, AdamW, get_linear_schedule_with_warmup
import numpy as np

In [10]:
def process_data(data_csv):
    df = pd.read_csv(data_csv)
    sentences = df.groupby("sentence")["word"].apply(list).values
    labels = df.groupby("sentence")["label"].apply(list).values
    tag_values = list(set(df["label"].values))
    tag_values.append("PAD")
    encoder = {t: i for i, t in enumerate(tag_values)}
    return sentences, labels, encoder, tag_values

In [11]:
train_sentences, train_labels, train_encoder, tag_values = process_data(training_params.TRAIN_DATA)
valid_sentences, valid_labels, _, _ = process_data(training_params.VALID_DATA)

train_dataset = PunctuationDataset(texts=train_sentences, labels=train_labels,
                                   tag2idx=train_encoder)
valid_dataset = PunctuationDataset(texts=valid_sentences, labels=valid_labels,
                                   tag2idx=train_encoder)

train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=training_params.BATCH_SIZE, num_workers=4)
valid_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=training_params.BATCH_SIZE, num_workers=4)

model = AlbertForTokenClassification.from_pretrained('ai4bharat/indic-bert',
                                                     num_labels=len(train_encoder),
                                                     output_attentions=False,
                                                     output_hidden_states=False)

Some weights of the model checkpoint at ai4bharat/indic-bert were not used when initializing AlbertForTokenClassification: ['predictions.bias', 'predictions.LayerNorm.weight', 'predictions.LayerNorm.bias', 'predictions.dense.weight', 'predictions.dense.bias', 'predictions.decoder.weight', 'predictions.decoder.bias', 'sop_classifier.classifier.weight', 'sop_classifier.classifier.bias']
- This IS expected if you are initializing AlbertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AlbertForTokenClassification were not initialized from the model checkpoint at ai4bharat/indic-bert and a

In [12]:
eg = next(iter(train_data_loader))

In [14]:
b_input_ids, b_input_mask, b_labels = eg['ids'], eg['mask'], eg['target_tag']

In [15]:
outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)

In [21]:
logits = outputs[1].detach().cpu().numpy()
label_ids = b_labels.to('cpu').numpy()

In [22]:
logits.shape

(2, 128, 6)

In [23]:
label_ids.shape

(2, 128)

In [27]:
outputs[0]

tensor(1.7043, grad_fn=<NllLossBackward>)

In [28]:
predictions, true_labels = [], []

In [29]:
predictions.extend([list(p) for p in np.argmax(logits, axis=2)])

In [36]:
true_labels.extend(label_ids)

In [37]:
true_labels

[array([4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]),
 array([4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])]

In [38]:
pred_tags = [tag_values[p_i] for p, l in zip(predictions, true_labels) for p_i, l_i in zip(p, l) if tag_values[l_i] != "PAD"]

In [41]:
pred_tags

['comma',
 'PAD',
 'viram',
 'none',
 'none',
 'none',
 'none',
 'none ',
 'comma',
 'none ',
 'comma',
 'PAD',
 'viram',
 'none',
 'none',
 'none',
 'none',
 'none']

In [42]:
valid_tags = [tag_values[l_i] for l in true_labels for l_i in l if tag_values[l_i] != "PAD"]

In [43]:
valid_tags

['none',
 'none',
 'none',
 'none',
 'none',
 'none',
 'none',
 'none',
 'none',
 'viram',
 'none',
 'none',
 'none',
 'none',
 'none',
 'none',
 'none',
 'viram']

In [51]:
print("Validation Accuracy: {}".format(metrics.accuracy_score(pred_tags, valid_tags)))

Validation Accuracy: 0.4444444444444444


In [59]:
print("Validation F1-Score: {}".format(metrics.f1_score(pred_tags, valid_tags, average='macro')))

Validation F1-Score: 0.128


In [58]:
metrics.classification_report(pred_tags, valid_tags)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


'              precision    recall  f1-score   support\n\n         PAD       0.00      0.00      0.00         2\n       comma       0.00      0.00      0.00         3\n        none       0.50      0.89      0.64         9\n       none        0.00      0.00      0.00         2\n       viram       0.00      0.00      0.00         2\n\n    accuracy                           0.44        18\n   macro avg       0.10      0.18      0.13        18\nweighted avg       0.25      0.44      0.32        18\n'

[2, 5, 0, 4, 4, 4, 4, 1, 2, 1, 2, 5, 0, 4, 4, 4, 4, 4]