Skip to content

Commit

Permalink
fix compute_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
airaria committed May 8, 2023
1 parent 3a4fb2b commit 2c105fa
Showing 1 changed file with 65 additions and 42 deletions.
107 changes: 65 additions & 42 deletions examples/notebook_examples/msra_ner.ipynb
Expand Up @@ -103,26 +103,9 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WgM6aADvnIm5",
"outputId": "500364c0-909a-47ae-c785-686dc2eccb12"
},
"outputs": [
{
"data": {
"text/plain": [
"['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']"
]
},
"execution_count": 6,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"id": "WgM6aADvnIm5"
},
"outputs": [],
"source": [
"label_list = datasets[\"train\"].features[f\"{task}_tags\"].feature.names\n",
"label_list"
Expand Down Expand Up @@ -187,7 +170,33 @@
"import numpy as np\n",
"\n",
"def compute_metrics(p):\n",
" predictions, labels = p\n",
" print(p.__dict__)\n",
" predictions = p.predictions\n",
" labels = p.label_ids\n",
" predictions = np.argmax(predictions, axis=2)\n",
"\n",
" # Remove ignored index (special tokens)\n",
" true_predictions = [\n",
" [label_list[p] for (p, l) in zip(prediction, label) if l != -100]\n",
" for prediction, label in zip(predictions, labels)\n",
" ]\n",
" true_labels = [\n",
" [label_list[l] for (p, l) in zip(prediction, label) if l != -100]\n",
" for prediction, label in zip(predictions, labels)\n",
" ]\n",
"\n",
" results = metric.compute(predictions=true_predictions, references=true_labels)\n",
" return {\n",
" \"precision\": results[\"overall_precision\"],\n",
" \"recall\": results[\"overall_recall\"],\n",
" \"f1\": results[\"overall_f1\"],\n",
" \"accuracy\": results[\"overall_accuracy\"],\n",
" }\n",
"\n",
"def compute_eval_metrics(p):\n",
" print(p.__dict__)\n",
" predictions = p.predictions[0]\n",
" labels = p.label_ids\n",
" predictions = np.argmax(predictions, axis=2)\n",
"\n",
" # Remove ignored index (special tokens)\n",
Expand Down Expand Up @@ -306,6 +315,18 @@
"trainer.train()"
]
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"id": "XxdTV3biv2o1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -330,8 +351,8 @@
"cell_type": "code",
"source": [
"from torch.utils.data import DataLoader, RandomSampler\n",
"train_dataset=tokenized_datasets[\"train\"]\n",
"train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=32) #prepare dataloader"
"train_dataset=tokenized_datasets[\"train\"].remove_columns(['id','tokens','ner_tags'])\n",
"train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=32,collate_fn=data_collator) #prepare dataloader"
],
"metadata": {
"id": "eLs-39XvLCzp"
Expand All @@ -350,9 +371,10 @@
"import textbrewer\n",
"from textbrewer import GeneralDistiller\n",
"from textbrewer import TrainingConfig, DistillationConfig\n",
"from transformers import BertForTokenClassification, BertConfig, AdamW,BertTokenizer\n",
"from transformers import BertForTokenClassification, BertConfig,BertTokenizer\n",
"from transformers import get_linear_schedule_with_warmup\n",
"import torch "
"import torch \n",
"from torch.optim import AdamW"
]
},
{
Expand All @@ -374,13 +396,13 @@
},
"outputs": [],
"source": [
"bert_config_T3 = BertConfig.from_json_file('/content/drive/MyDrive/TextBrewer-master/examples/student_config/bert_base_cased_config/bert_config_L3.json') \n",
"bert_config_T3 = BertConfig.from_json_file('/content/bert_config_L3.json') \n",
"bert_config_T3.output_hidden_states = True\n",
"bert_config_T3.num_labels = len(label_list)\n",
"\n",
"student_model = BertForTokenClassification(bert_config_T3)\n",
"\n",
"bert_config = BertConfig.from_json_file('/content/drive/MyDrive/TextBrewer-master/examples/student_config/bert_base_cased_config/bert_config.json')\n",
"bert_config = BertConfig.from_json_file('/content/bert_config.json')\n",
"bert_config.output_hidden_states = True\n",
"bert_config.num_labels = len(label_list)\n",
"\n",
Expand All @@ -402,6 +424,18 @@
"After the code has finished running, the distilled model will be available under 'saved_model' in the Colab file list."
]
},
{
"cell_type": "code",
"source": [
"def proc_fn(batch):\n",
" return {'input_ids':batch['input_ids'],'token_type_ids':batch['token_type_ids'],'attention_mask':batch['attention_mask'],'labels':batch['labels']}"
],
"metadata": {
"id": "DLXYbGlb55Zm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -436,7 +470,8 @@
"\n",
"\n",
"with distiller:\n",
" distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)"
" distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None,\n",
" batch_postprocessor=proc_fn)"
]
},
{
Expand All @@ -456,7 +491,7 @@
},
"outputs": [],
"source": [
"bert_config_T3 = BertConfig.from_json_file('/content/drive/MyDrive/data/bert_config/bert_config_L3.json')\n",
"bert_config_T3 = BertConfig.from_json_file('/content/bert_config_L3.json')\n",
"\n",
"bert_config_T3.output_hidden_states = True\n",
"bert_config_T3.num_labels = len(label_list)\n",
Expand Down Expand Up @@ -504,7 +539,7 @@
" eval_dataset=tokenized_datasets[\"test\"],\n",
" data_collator=data_collator,\n",
" tokenizer=tokenizer,\n",
" compute_metrics=compute_metrics\n",
" compute_metrics=compute_eval_metrics\n",
")"
]
},
Expand All @@ -518,23 +553,11 @@
"source": [
"trainer.evaluate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4hq_TkiP1-I6"
},
"outputs": [],
"source": [
""
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "msra.ipynb",
"provenance": [],
"include_colab_link": true
Expand Down

0 comments on commit 2c105fa

Please sign in to comment.