In [1]:
from datasets import load_dataset
from transformers import BertTokenizerFast, Trainer, TrainingArguments, set_seed

from model import HREBCRF
from utils import tokenize, compute_metrics, WeightLoggerCallback

In [2]:
dataset_name = 'PassbyGrocer/msra-ner'
bert_model = 'hfl/chinese-roberta-wwm-ext-large'
max_length = 64  # passed to `tokenize` as its max_length
seed = 42  # same value as TrainingArguments' default `seed`

# Seed python/numpy/torch here: the HREBCRF heads (BiLSTM, MEGA, CRF, classifier)
# are freshly initialized when the model is built in the next cell, i.e. before
# the Trainer applies its own seed, so without this the run is not reproducible.
set_seed(seed)

In [3]:
# Load every split of the NER corpus, then the pretrained backbone and its tokenizer.
dataset = load_dataset(dataset_name)
train_dataset, test_dataset, val_dataset = (
    dataset[split_name] for split_name in ('train', 'test', 'validation')
)
# `ner_tags` is a sequence of ClassLabel values; its class count is passed to HREBCRF.
num_labels = train_dataset.features["ner_tags"].feature.num_classes
model = HREBCRF.from_pretrained(bert_model, num_labels=num_labels)
tokenizer = BertTokenizerFast.from_pretrained(bert_model)

Some weights of HREBCRF were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext-large and are newly initialized: ['bilstm.bias_hh_l0', 'bilstm.bias_hh_l0_reverse', 'bilstm.bias_hh_l1', 'bilstm.bias_hh_l1_reverse', 'bilstm.bias_ih_l0', 'bilstm.bias_ih_l0_reverse', 'bilstm.bias_ih_l1', 'bilstm.bias_ih_l1_reverse', 'bilstm.weight_hh_l0', 'bilstm.weight_hh_l0_reverse', 'bilstm.weight_hh_l1', 'bilstm.weight_hh_l1_reverse', 'bilstm.weight_ih_l0', 'bilstm.weight_ih_l0_reverse', 'bilstm.weight_ih_l1', 'bilstm.weight_ih_l1_reverse', 'classifier.bias', 'classifier.weight', 'crf.end_transitions', 'crf.start_transitions', 'crf.transitions', 'layer_norm.bias', 'layer_norm.weight', 'mega.Uh', 'mega.Wh', 'mega.bh', 'mega.multi_headed_ema.alphas', 'mega.multi_headed_ema.dampen_factors', 'mega.multi_headed_ema.expansion', 'mega.multi_headed_ema.reduction', 'mega.single_headed_attn.offsetscale.beta', 'mega.single_headed_attn.offsetscale.gamma', 'mega.single_headed_attn.rel_pos_bias

In [4]:
MODEL_COLUMNS = ['input_ids', 'token_type_ids', 'attention_mask', 'label_ids']


def prepare_split(split):
    """Turn a raw MSRA-NER split into model-ready torch tensors.

    Renames `ner_tags` to `label_ids`, runs `tokenize` over the whole split as a
    single batch, and exposes only MODEL_COLUMNS in torch format.

    Args:
        split: a `datasets.Dataset` that still has its original `ner_tags` column.

    Returns:
        A new, tokenized `datasets.Dataset` formatted for torch.
    """
    renamed = split.rename_column('ner_tags', 'label_ids')
    # batch_size=None feeds the entire split to `tokenize` as one batch
    # (same behaviour as the original batch_size=len(split)).
    tokenized = renamed.map(lambda x: tokenize(x, tokenizer, max_length), batched=True, batch_size=None)
    tokenized.set_format('torch', columns=MODEL_COLUMNS)
    return tokenized


# Build each split from the untouched DatasetDict so this cell can be re-run safely
# (renaming an already-renamed split would raise).
train_dataset = prepare_split(dataset['train'])
test_dataset = prepare_split(dataset['test'])
# Prepared too, so it can serve as the per-epoch eval set instead of the test split.
val_dataset = prepare_split(dataset['validation'])

Map:   0%|          | 0/46364 [00:00<?, ? examples/s]

Map:   0%|          | 0/4365 [00:00<?, ? examples/s]

In [5]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # NOTE(review): 10 warmup steps is <0.1% of the ~14.5k total steps; consider warmup_ratio.
    warmup_steps=10,
    weight_decay=0.01,
    eval_strategy="epoch",
    # Checkpoint once per epoch (in step with evaluation) instead of the default
    # every 500 steps, which for a large backbone plus optimizer state means
    # dozens of multi-GB checkpoints over a 10-epoch run.
    save_strategy="epoch",
    save_total_limit=2,  # older checkpoints beyond the 2 most recent are deleted
)

In [6]:

def _metrics_fn(eval_pred):
    """Adapter for Trainer: forward predictions plus the full DatasetDict to compute_metrics."""
    return compute_metrics(eval_pred, dataset)


# NOTE(review): evaluation runs on the *test* split every epoch while the
# validation split is unused, so per-epoch scores are test-set scores. Consider
# eval_dataset=val_dataset (prepared like the other splits) and reserving
# test_dataset for a single final trainer.evaluate().
trainer_kwargs = dict(
    model=model,
    args=training_args,
    compute_metrics=_metrics_fn,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    callbacks=[WeightLoggerCallback()],
)
trainer = Trainer(**trainer_kwargs)
trainer.train()

  score = torch.where(mask[i].unsqueeze(1), next_score, score)


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,44.4211,44.428574,0.942418,0.925364,0.933683
2,24.6181,40.110237,0.922253,0.914421,0.918228
3,17.0536,57.471779,0.924063,0.931646,0.927705
4,14.5891,76.849312,0.913823,0.928152,0.920597
5,11.7824,89.700432,0.931988,0.924812,0.928205
6,8.3881,74.115349,0.935111,0.937165,0.936002
7,4.009,77.435707,0.933294,0.931544,0.932287
8,3.3499,64.52845,0.934997,0.936859,0.935725
9,1.3022,66.708778,0.945213,0.940312,0.94267
10,1.1795,64.48259,0.941601,0.939403,0.940371


Current weight of:
r_lstm: 0.6216
r_mega: 0.6216

Current weight of:
r_lstm: 0.6212
r_mega: 0.6212

              precision    recall  f1-score   support

       B-LOC       0.98      0.95      0.97      2674
       I-LOC       0.98      0.92      0.95      4076
       B-ORG       0.90      0.92      0.91      1218
       I-ORG       0.94      0.93      0.93      5054
       B-PER       0.98      0.98      0.98      1304
       I-PER       0.98      0.98      0.98      2425

   micro avg       0.96      0.94      0.95     16751
   macro avg       0.96      0.95      0.95     16751
weighted avg       0.96      0.94      0.95     16751

Current weight of:
r_lstm: 0.6209
r_mega: 0.6209

Current weight of:
r_lstm: 0.6209
r_mega: 0.6209

Current weight of:
r_lstm: 0.6207
r_mega: 0.6207

Current weight of:
r_lstm: 0.6205
r_mega: 0.6205

              precision    recall  f1-score   support

       B-LOC       0.98      0.95      0.97      2674
       I-LOC       0.96      0.94      0.95     

TrainOutput(global_step=14490, training_loss=15.689221124010633, metrics={'train_runtime': 4054.121, 'train_samples_per_second': 114.363, 'train_steps_per_second': 3.574, 'total_flos': 5.661662440280064e+16, 'train_loss': 15.689221124010633, 'epoch': 10.0})