Skip to content

Commit

Permalink
add classifier_pytorch ocnli
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaojinglu committed May 10, 2021
1 parent 42961f9 commit 3d8e09a
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Expand Up @@ -15,4 +15,7 @@ prev_trained_model
.idea
.idea/CLUE.iml
.idea/misc.iml
__pycache__
baselines/models_pytorch/classifier_pytorch/CLUEdatasets
baselines/models_pytorch/classifier_pytorch/*output
baselines/models/bert/*output
__pycache__
2 changes: 2 additions & 0 deletions baselines/models/bert/run_classifier.py
Expand Up @@ -29,6 +29,7 @@
import tokenization
import tensorflow as tf
import sys
import pdb
sys.path.append('..')
from classifier_utils import *

Expand Down Expand Up @@ -772,6 +773,7 @@ def main(_):
num_warmup_steps = None
if FLAGS.do_train:
print("data_dir:", FLAGS.data_dir)
pdb.set_trace()
train_examples = processor.get_train_examples(FLAGS.data_dir)
num_train_steps = int(
len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
Expand Down
Expand Up @@ -39,6 +39,8 @@ def compute_metrics(task_name, preds, labels):
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "cmnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "ocnli":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "iflytek":
return {"acc": simple_accuracy(preds, labels)}
elif task_name == "wsc":
Expand Down
42 changes: 41 additions & 1 deletion baselines/models_pytorch/classifier_pytorch/processors/clue.py
Expand Up @@ -250,6 +250,41 @@ def _create_examples(self, lines, set_type):
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples

class OcnliProcessor(DataProcessor):
"""Processor for the CMNLI data set (CLUE version)."""

def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_json(os.path.join(data_dir, "train.json")), "train")

def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_json(os.path.join(data_dir, "dev.json")), "dev")

def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_json(os.path.join(data_dir, "test.json")), "test")

def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]

def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = line["sentence1"]
text_b = line["sentence2"]
label = str(line["label"]) if set_type != 'test' else 'neutral'
if label.strip()=='-':
continue
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples

class CmnliProcessor(DataProcessor):
"""Processor for the CMNLI data set (CLUE version)."""
Expand Down Expand Up @@ -280,7 +315,9 @@ def _create_examples(self, lines, set_type):
guid = "%s-%s" % (set_type, i)
text_a = line["sentence1"]
text_b = line["sentence2"]
label = str(line["gold_label"]) if set_type != 'test' else 'neutral'
label = str(line["label"]) if set_type != 'test' else 'neutral'
if label.strip()=='-':
continue
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
Expand Down Expand Up @@ -446,6 +483,7 @@ def _create_examples_version2(self, lines, set_type):
clue_tasks_num_labels = {
'iflytek': 119,
'cmnli': 3,
'ocnli': 3,
'afqmc': 2,
'csl': 2,
'wsc': 2,
Expand All @@ -457,6 +495,7 @@ def _create_examples_version2(self, lines, set_type):
'tnews': TnewsProcessor,
'iflytek': IflytekProcessor,
'cmnli': CmnliProcessor,
'ocnli': OcnliProcessor,
'afqmc': AfqmcProcessor,
'csl': CslProcessor,
'wsc': WscProcessor,
Expand All @@ -467,6 +506,7 @@ def _create_examples_version2(self, lines, set_type):
'tnews': "classification",
'iflytek': "classification",
'cmnli': "classification",
'ocnli': "classification",
'afqmc': "classification",
'csl': "classification",
'wsc': "classification",
Expand Down
@@ -0,0 +1,85 @@
#!/usr/bin/env bash
# @Author: bo.shi
# @Date: 2019-11-04 09:56:36
# @Last Modified by: bo.shi
# @Last Modified time: 2020-01-01 11:40:23

TASK_NAME="ocnli"
MODEL_NAME="bert-base-chinese"
CURRENT_DIR=$(cd -P -- "$(dirname -- "$0")" && pwd -P)
export CUDA_VISIBLE_DEVICES="0"
export BERT_PRETRAINED_MODELS_DIR=$CURRENT_DIR/prev_trained_model
export BERT_WWM_DIR=$BERT_PRETRAINED_MODELS_DIR/$MODEL_NAME
export GLUE_DATA_DIR=$CURRENT_DIR/CLUEdatasets

# download and unzip dataset
if [ ! -d $GLUE_DATA_DIR ]; then
mkdir -p $GLUE_DATA_DIR
echo "makedir $GLUE_DATA_DIR"
fi
cd $GLUE_DATA_DIR
if [ ! -d $TASK_NAME ]; then
mkdir $TASK_NAME
echo "makedir $GLUE_DATA_DIR/$TASK_NAME"
fi
cd $TASK_NAME
if [ ! -f "train.json" ] || [ ! -f "dev.json" ] || [ ! -f "test.json" ]; then
rm *
wget https://storage.googleapis.com/cluebenchmark/tasks/cmnli_public.zip
unzip cmnli_public.zip
rm cmnli_public.zip
else
echo "data exists"
fi
echo "Finish download dataset."

# make output dir
if [ ! -d $CURRENT_DIR/${TASK_NAME}_output ]; then
mkdir -p $CURRENT_DIR/${TASK_NAME}_output
echo "makedir $CURRENT_DIR/${TASK_NAME}_output"
fi

# run task
cd $CURRENT_DIR
echo "Start running..."
if [ $# == 0 ]; then
python run_classifier.py \
--model_type=bert \
--model_name_or_path=$MODEL_NAME \
--task_name=$TASK_NAME \
--do_train \
--do_eval \
--do_lower_case \
--data_dir=$GLUE_DATA_DIR/${TASK_NAME}/ \
--max_seq_length=128 \
--per_gpu_train_batch_size=16 \
--per_gpu_eval_batch_size=16 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--logging_steps=24487 \
--save_steps=24487 \
--output_dir=$CURRENT_DIR/${TASK_NAME}_output/ \
--overwrite_output_dir \
--seed=42
elif [ $1 == "predict" ]; then
echo "Start predict..."
python run_classifier.py \
--model_type=bert \
--model_name_or_path=$MODEL_NAME \
--task_name=$TASK_NAME \
--do_predict \
--do_lower_case \
--data_dir=$GLUE_DATA_DIR/${TASK_NAME}/ \
--max_seq_length=128 \
--per_gpu_train_batch_size=16 \
--per_gpu_eval_batch_size=16 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--logging_steps=24487 \
--save_steps=24487 \
--output_dir=$CURRENT_DIR/${TASK_NAME}_output/ \
--overwrite_output_dir \
--seed=42
fi


0 comments on commit 3d8e09a

Please sign in to comment.