In [1]:
from model import HREBCRF
from utils import tokenize, compute_metrics, WeightLoggerCallback
from datasets import load_dataset
from transformers import BertTokenizerFast, Trainer, TrainingArguments

In [2]:
# Run configuration shared by the cells below.
# Dataset identifier handed to `load_dataset`.
dataset_name = 'PassbyGrocer/msra-ner'
# Pretrained checkpoint name used for both the HREBCRF model and the tokenizer.
bert_model = 'hfl/chinese-roberta-wwm-ext-large'
# Passed to `utils.tokenize` for every split; presumably the per-example
# token limit -- confirm against the `tokenize` implementation.
max_length = 64

In [None]:
# Load the dataset and bind its three splits to the names later cells use.
dataset = load_dataset(dataset_name)
train_dataset, test_dataset, val_dataset = (
    dataset[split_name] for split_name in ('train', 'test', 'validation')
)

# The label count is the number of names listed on the train split's
# `ner_tags` feature.
num_labels = len(train_dataset.features["ner_tags"].feature.names)

# The model and the tokenizer are both loaded from the same checkpoint name.
model = HREBCRF.from_pretrained(bert_model, num_labels=num_labels)
tokenizer = BertTokenizerFast.from_pretrained(bert_model)

In [None]:
def _prepare_split(split_name):
    """Return the named split of `dataset`, renamed, tokenized and torch-formatted.

    The split is always derived from the untouched `dataset[split_name]`, so
    this cell can be re-run safely: the original version renamed the already
    rebound `train_dataset` / `test_dataset`, which fails on a second run
    because the `ner_tags` column no longer exists.

    Args:
        split_name: Key of the split inside `dataset` (e.g. 'train', 'test').

    Returns:
        The processed split, exposing only the `input_ids`, `token_type_ids`,
        `attention_mask` and `label_ids` columns as torch tensors.
    """
    split = dataset[split_name].rename_column('ner_tags', 'label_ids')
    # The whole split is tokenized as a single batch (batch_size == split size).
    # NOTE(review): presumably so `tokenize` sees every example at once; this
    # keeps the full split in memory -- confirm it is intentional.
    split = split.map(
        lambda batch: tokenize(batch, tokenizer, max_length),
        batched=True,
        batch_size=len(split),
    )
    split.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label_ids'])
    return split


train_dataset = _prepare_split('train')
test_dataset = _prepare_split('test')

In [5]:
# Trainer hyper-parameters, gathered in one mapping so they can be inspected
# before the arguments object is built.
_training_kwargs = {
    'output_dir': './results',
    'num_train_epochs': 10,
    'per_device_train_batch_size': 32,
    'per_device_eval_batch_size': 32,
    'warmup_steps': 10,
    'weight_decay': 0.01,
    'eval_strategy': "epoch",
}
training_args = TrainingArguments(**_training_kwargs)

In [None]:

def _metrics_with_dataset(eval_prediction):
    """Single-argument adapter for Trainer: forwards to `compute_metrics` with `dataset` as second argument."""
    return compute_metrics(eval_prediction, dataset)


# NOTE(review): evaluation runs on the test split; `val_dataset` is loaded
# earlier but is not used by this cell.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=_metrics_with_dataset,
    callbacks=[WeightLoggerCallback()],
)
trainer.train()