In [1]:
from sklearn_crfsuite import CRF

In [2]:
from DataProcess.model_utils import sent2features

# 定义模型

In [3]:
class CRFModel:
    """Sequence tagger backed by a ``sklearn_crfsuite.CRF`` instance.

    Sentences are converted with ``sent2features`` before being handed to the
    underlying CRF, both when fitting and when predicting.
    """

    def __init__(self, algorithm='lbfgs',
                 c1=0.1, c2=0.1, max_iterations=100,
                 all_possible_transitions=False):
        """Build the underlying CRF; every argument is forwarded unchanged.

        Args:
            algorithm: Training algorithm name passed to ``CRF``.
            c1: Passed to ``CRF`` as ``c1`` (L1 regularization coefficient in
                the sklearn-crfsuite documentation).
            c2: Passed to ``CRF`` as ``c2`` (L2 regularization coefficient in
                the sklearn-crfsuite documentation).
            max_iterations: Passed to ``CRF`` as ``max_iterations``.
            all_possible_transitions: Passed to ``CRF`` under the same name.
        """
        crf_options = dict(
            algorithm=algorithm,
            c1=c1,
            c2=c2,
            max_iterations=max_iterations,
            all_possible_transitions=all_possible_transitions,
        )
        self.model = CRF(**crf_options)

    def train(self, sentences, tag_lists):
        """Fit the CRF on ``sentences`` (featurized here) and ``tag_lists``."""
        self.model.fit(list(map(sent2features, sentences)), tag_lists)

    def test(self, sentences):
        """Return the CRF's predicted tag lists for ``sentences``."""
        return self.model.predict(list(map(sent2features, sentences)))

# 读取数据

In [4]:
from DataProcess.data import build_corpus

In [5]:
# Load the three dataset splits. The "train" call is unpacked into four
# values (word lists, tag lists, word2id, tag2id); the "dev" and "test" calls
# pass make_vocab=False and are unpacked into word lists and tag lists only.
(train_word_lists,
 train_tag_lists,
 word2id,
 tag2id) = build_corpus("train")
(dev_word_lists,
 dev_tag_lists) = build_corpus("dev", make_vocab=False)
(test_word_lists,
 test_tag_lists) = build_corpus("test", make_vocab=False)

# 训练模型

In [7]:
from DataProcess.utils import save_model
from DataProcess.evaluating import Metrics

In [8]:
def crf_train_eval(train_data, test_data, remove_O=False,
                   model_path="./ckpts/crf.pkl"):
    """Train a CRF tagger, save it, and report its scores on ``test_data``.

    Args:
        train_data: ``(word_lists, tag_lists)`` pair used to fit the model.
        test_data: ``(word_lists, tag_lists)`` pair used for evaluation.
        remove_O: Forwarded unchanged to ``Metrics`` (presumably excludes the
            ``O`` tag from the reported scores -- confirm against ``Metrics``).
        model_path: Destination handed to ``save_model`` for the trained
            model. Defaults to the previously hard-coded ``"./ckpts/crf.pkl"``.

    Returns:
        The tag lists predicted for the test sentences.
    """
    # Local import so this cell does not depend on another cell's imports.
    from pathlib import Path

    train_word_lists, train_tag_lists = train_data
    test_word_lists, test_tag_lists = test_data

    # save_model is defined elsewhere; create the checkpoint directory before
    # training in case save_model does not, so a missing directory cannot
    # waste a full training run (assumption -- TODO confirm save_model).
    Path(model_path).parent.mkdir(parents=True, exist_ok=True)

    # Train the CRF model
    crf_model = CRFModel()
    crf_model.train(train_word_lists, train_tag_lists)

    save_model(crf_model, model_path)
    pred_tag_lists = crf_model.test(test_word_lists)

    # Print per-tag scores and the confusion matrix for the test split.
    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()
    return pred_tag_lists

In [9]:
# Fit on the training split and evaluate on the test split; the returned
# predictions are kept in crf_pred.
crf_pred = crf_train_eval(
    train_data=(train_word_lists, train_tag_lists),
    test_data=(test_word_lists, test_tag_lists),
)

           precision    recall  f1-score   support
    M-LOC     1.0000    0.8095    0.8947        21
    E-EDU     0.9910    0.9821    0.9865       112
        O     0.9630    0.9732    0.9681      5190
    B-LOC     1.0000    0.8333    0.9091         6
  B-TITLE     0.9376    0.9339    0.9358       772
   B-CONT     1.0000    1.0000    1.0000        28
    B-EDU     0.9820    0.9732    0.9776       112
   B-RACE     1.0000    1.0000    1.0000        14
    B-ORG     0.9636    0.9566    0.9601       553
   M-CONT     1.0000    1.0000    1.0000        53
    B-PRO     0.9091    0.9091    0.9091        33
    M-ORG     0.9523    0.9563    0.9543      4325
   E-RACE     1.0000    1.0000    1.0000        14
  E-TITLE     0.9857    0.9819    0.9838       772
    M-EDU     0.9824    0.9330    0.9570       179
   M-NAME     1.0000    0.9756    0.9877        82
   B-NAME     1.0000    0.9821    0.9910       112
    E-PRO     0.9091    0.9091    0.9091        33
    M-PRO     0.8354    0.9706 