Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,3 +17,5 @@ m2r | |
|
||
# develop | ||
jupyterlab | ||
tabulate | ||
pandas |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Author : BrikerMan | ||
# Site : https://eliyar.biz | ||
|
||
# Time : 2020/8/29 11:16 上午 | ||
# File : classifications.py | ||
# Project : Kashgari | ||
|
||
import logging | ||
import time | ||
from typing import Type | ||
|
||
import pandas as pd | ||
|
||
from kashgari.corpus import SMP2018ECDTCorpus | ||
from kashgari.embeddings import BertEmbedding | ||
from kashgari.tasks.classification import ABCClassificationModel | ||
from kashgari.tasks.classification import ALL_MODELS | ||
from test_performance.tools import get_bert_path | ||
|
||
|
||
class ClassificationPerformance: | ||
|
||
MODELS = ALL_MODELS | ||
|
||
def run_with_model_class(self, model_class: Type[ABCClassificationModel], epochs: int): | ||
bert_path = get_bert_path() | ||
|
||
train_x, train_y = SMP2018ECDTCorpus.load_data('train') | ||
valid_x, valid_y = SMP2018ECDTCorpus.load_data('valid') | ||
test_x, test_y = SMP2018ECDTCorpus.load_data('test') | ||
|
||
bert_embed = BertEmbedding(bert_path) | ||
model = model_class(bert_embed) | ||
model.fit(train_x, train_y, valid_x, valid_y, epochs=epochs) | ||
|
||
report = model.evaluate(test_x, test_y) | ||
del model | ||
del bert_embed | ||
return report | ||
|
||
def run(self, epochs=10): | ||
logging.basicConfig(level='DEBUG') | ||
reports = [] | ||
for model_class in self.MODELS: | ||
logging.info("="*80) | ||
logging.info("") | ||
logging.info("") | ||
logging.info(f" Start Training {model_class.__name__}") | ||
logging.info("") | ||
logging.info("") | ||
logging.info("=" * 80) | ||
start = time.time() | ||
report = self.run_with_model_class(model_class, epochs=epochs) | ||
time_cost = time.time() - start | ||
reports.append({ | ||
'model_name': model_class.__name__, | ||
"epoch": epochs, | ||
'f1-score': report['f1-score'], | ||
'precision': report['precision'], | ||
'recall': report['recall'], | ||
'time': f"{int(time_cost//60):02}:{int(time_cost%60):02}" | ||
}) | ||
|
||
df = pd.DataFrame(reports) | ||
print(df.to_markdown()) | ||
|
||
|
||
if __name__ == '__main__': | ||
p = ClassificationPerformance() | ||
p.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Author : BrikerMan | ||
# Site : https://eliyar.biz | ||
|
||
# Time : 2020/8/29 11:47 上午 | ||
# File : labeling.py | ||
# Project : Kashgari | ||
|
||
from typing import Type | ||
from kashgari.corpus import ChineseDailyNerCorpus | ||
from kashgari.embeddings import BertEmbedding | ||
from kashgari.tasks.labeling import ABCLabelingModel | ||
from kashgari.tasks.labeling import ALL_MODELS | ||
from test_performance.classifications import ClassificationPerformance | ||
from test_performance.tools import get_bert_path | ||
|
||
|
||
class LabelingPerformance(ClassificationPerformance): | ||
MODELS = ALL_MODELS | ||
|
||
def run_with_model_class(self, model_class: Type[ABCLabelingModel], epochs: int): | ||
bert_path = get_bert_path() | ||
|
||
train_x, train_y = ChineseDailyNerCorpus.load_data('train') | ||
valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid') | ||
test_x, test_y = ChineseDailyNerCorpus.load_data('test') | ||
|
||
bert_embed = BertEmbedding(bert_path) | ||
model = model_class(bert_embed) | ||
model.fit(train_x, train_y, valid_x, valid_y, epochs=epochs) | ||
|
||
report = model.evaluate(test_x, test_y) | ||
del model | ||
del bert_embed | ||
return report | ||
|
||
|
||
if __name__ == '__main__': | ||
p = LabelingPerformance() | ||
p.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Performance | ||
|
||
This is for run performance report on models with bert-embedding. | ||
|
||
|
||
## Classification | ||
|
||
```python | ||
from kashgari.corpus import SMP2018ECDTCorpus | ||
|
||
train_x, train_y = SMP2018ECDTCorpus.load_data('train') | ||
valid_x, valid_y = SMP2018ECDTCorpus.load_data('valid') | ||
test_x, test_y = SMP2018ECDTCorpus.load_data('test') | ||
``` | ||
|
||
| | model_name | epoch | f1-score | precision | recall | time | | ||
|---:|:--------------------|--------:|-----------:|------------:|---------:|:-------| | ||
| 0 | BiGRU_Model | 10 | 0.9335 | 0.937795 | 0.935065 | 00:33 | | ||
| 1 | BiLSTM_Model | 10 | 0.929075 | 0.930548 | 0.92987 | 00:33 | | ||
| 2 | CNN_Attention_Model | 10 | 0.862197 | 0.888507 | 0.866234 | 00:27 | | ||
| 3 | CNN_GRU_Model | 10 | 0.840024 | 0.886519 | 0.850649 | 00:28 | | ||
| 4 | CNN_LSTM_Model | 10 | 0.424649 | 0.551247 | 0.511688 | 00:27 | | ||
| 5 | CNN_Model | 10 | 0.930336 | 0.938373 | 0.931169 | 00:26 | | ||
|
||
## NER | ||
|
||
```python | ||
from kashgari.corpus import ChineseDailyNerCorpus | ||
|
||
train_x, train_y = ChineseDailyNerCorpus.load_data('train') | ||
valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid') | ||
test_x, test_y = ChineseDailyNerCorpus.load_data('test') | ||
``` | ||
|
||
| | model_name | epoch | f1-score | precision | recall | time | | ||
|---:|:---------------|--------:|-----------:|------------:|---------:|:-------| | ||
| 0 | BiGRU_Model | 10 | 0.917219 | 0.915018 | 0.919474 | 16:30 | | ||
| 1 | BiLSTM_Model | 10 | 0.918491 | 0.908189 | 0.929361 | 16:37 | | ||
| 2 | CNN_LSTM_Model | 10 | 0.925621 | 0.91963 | 0.932223 | 16:31 | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# Author : BrikerMan | ||
# Site : https://eliyar.biz | ||
|
||
# Time : 2020/8/29 11:11 上午 | ||
# File : tools.py | ||
# Project : Kashgari | ||
|
||
import os | ||
import zipfile | ||
import pathlib | ||
from tensorflow.keras.utils import get_file | ||
from kashgari import macros as K | ||
|
||
|
||
def get_bert_path() -> str: | ||
url = "https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip" | ||
bert_path = os.path.join(K.DATA_PATH, 'datasets', 'bert') | ||
model_path = os.path.join(bert_path, 'chinese_L-12_H-768_A-12') | ||
pathlib.Path(bert_path).mkdir(parents=True, exist_ok=True) | ||
if not os.path.exists(model_path): | ||
zip_file_path = get_file("bert/chinese_L-12_H-768_A-12.zip", | ||
url, | ||
cache_dir=K.DATA_PATH, ) | ||
|
||
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: | ||
zip_ref.extractall(bert_path) | ||
return model_path | ||
|
||
|
||
if __name__ == '__main__': | ||
for k, v in os.environ.items(): | ||
print(f'{k:20}: {v}') | ||
get_bert_path() |