Skip to content

Commit

Permalink
Clean up and fix E2E NER baselines (#12)
Browse files (browse the repository at this point in the history)
* add constants

* fix path

* clean up pkl & run black

* add build vp ner lm & fix text_ner path

* fix build_vp_ner_lm & remove dict
  • Loading branch information
fwu-asapp committed Feb 15, 2022
1 parent 0607622 commit 78d06f3
Show file tree
Hide file tree
Showing 22 changed files with 111 additions and 102 deletions.
2 changes: 1 addition & 1 deletion baselines/ner/configs/w2v2_ner_1gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ task:
_name: audio_finetuning
data: ???
normalize: false # must be consistent with pre-training
labels: ltr
labels: raw.ltr

dataset:
num_workers: 0
Expand Down
8 changes: 3 additions & 5 deletions baselines/ner/e2e_scripts/ft-w2v2-base.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@ pretrained_ckpt=`realpath $pretrained_ckpt`

config_dir=baselines/ner/configs
config=w2v2_ner_1gpu
label_type=raw
train_subset=fine-tune_$label_type
valid_subset=dev_$label_type
python slue_toolkit/prepare/create_dict.py manifest/slue-voxpopuli/e2e_ner/fine-tune_$label_type.ltr manifest/slue-voxpopuli/e2e_ner/dict.ltr.txt
python slue_toolkit/prepare/create_dict.py manifest/slue-voxpopuli/e2e_ner/fine-tune_$label_type.wrd manifest/slue-voxpopuli/e2e_ner/dict.wrd.txt
train_subset=fine-tune
valid_subset=dev

normalize=false
lr=5e-5
Expand All @@ -38,6 +35,7 @@ fairseq-hydra-train \
hydra.output_subdir=$save \
common.tensorboard_logdir=$tb_save \
task.data=$data \
task.labels="raw.ltr" \
dataset.train_subset=$train_subset \
dataset.valid_subset=$valid_subset \
distributed_training.distributed_world_size=$ngpu \
Expand Down
8 changes: 3 additions & 5 deletions baselines/ner/e2e_scripts/ft-w2v2-large.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@ pretrained_ckpt=`realpath $pretrained_ckpt`

config_dir=baselines/ner/configs
config=w2v2_ner_1gpu
label_type=raw
train_subset=fine-tune_$label_type
valid_subset=dev_$label_type
python slue_toolkit/prepare/create_dict.py manifest/slue-voxpopuli/e2e_ner/fine-tune_$label_type.ltr manifest/slue-voxpopuli/e2e_ner/dict.ltr.txt
python slue_toolkit/prepare/create_dict.py manifest/slue-voxpopuli/e2e_ner/fine-tune_$label_type.wrd manifest/slue-voxpopuli/e2e_ner/dict.wrd.txt
train_subset=fine-tune
valid_subset=dev

normalize=true
lr=1e-5
Expand All @@ -38,6 +35,7 @@ fairseq-hydra-train \
hydra.output_subdir=$save \
common.tensorboard_logdir=$tb_save \
task.data=$data \
task.labels="raw.ltr" \
dataset.train_subset=$train_subset \
dataset.valid_subset=$valid_subset \
distributed_training.distributed_world_size=$ngpu \
Expand Down
1 change: 1 addition & 0 deletions scripts/build_vp_ner_lm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Build the VoxPopuli NER language model by delegating to create_ngram.sh.
# The caller's first argument is forwarded unchanged; it is quoted so a value
# containing spaces or glob characters reaches create_ngram.sh as one argument.
# NOTE(review): the trailing "4" is presumably the n-gram order and
# save/kenlm/vp_ner the output location -- confirm against create_ngram.sh.
bash scripts/create_ngram.sh "$1" manifest/slue-voxpopuli/e2e_ner/fine-tune.raw.wrd save/kenlm/vp_ner 4
7 changes: 0 additions & 7 deletions scripts/download_datasets.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,3 @@ for label in ltr wrd; do
python slue_toolkit/prepare/create_dict.py manifest/slue-${split}/fine-tune.${label} manifest/slue-${split}/dict.${label}.txt
done
done

#5. copy files
for session in dev fine-tune; do
for label in raw combined; do
cp ./manifest/slue-voxpopuli/${session}.tsv ./manifest/slue-voxpopuli/e2e_ner/${session}_${label}.tsv
done
done
2 changes: 1 addition & 1 deletion slue_toolkit/eval/eval_w2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def eval_asr(
# eval_log_file = os.path.join(ckpt, 'eval.log')
eval_log_file = None
if save_results:
results_path = os.path.join(model, "decode")
results_path = os.path.join(model, "decode", lm.replace('/', '_'))
else:
results_path = None
emission_path = (
Expand Down
17 changes: 7 additions & 10 deletions slue_toolkit/eval/eval_w2v_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
sys.path.insert(0, "../")

from slue_toolkit.eval import eval_utils
from slue_toolkit.generic_utils import read_lst, save_pkl, load_pkl
from slue_toolkit.generic_utils import (
read_lst,
save_pkl,
spl_char_to_entity,
raw_to_combined_tag_map,
)


def make_distinct(label_lst):
Expand All @@ -32,22 +37,14 @@ def get_gt_pred(score_type, eval_label, eval_set, decoded_data_dir):
"""
Read the GT and predicted utterances in the entity format [(word1, tag1), (word2, tag2), ...]
"""
spl_char_to_entity = load_pkl(
os.path.join("slue_toolkit/prepare/files/", "spl_char_to_entity.pkl")
)
entity_end_char = "]"
entity_to_spl_char = {}
for spl_char, entity in spl_char_to_entity.items():
entity_to_spl_char[entity] = spl_char

if eval_label == "combined":
label_map_dct = load_pkl(
os.path.join("slue_toolkit/prepare/files/", "raw_to_combined_tags.pkl")
)

def update_label_lst(lst, phrase, label):
if eval_label == "combined":
label = label_map_dct[label]
label = raw_to_combined_tag_map[label]
if label != "DISCARD":
if score_type == "label":
lst.append((label, "phrase"))
Expand Down
5 changes: 1 addition & 4 deletions slue_toolkit/eval/infer_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ def add_asr_eval_argument(parser):
help="temperature scaling of the logits",
)
parser.add_argument(
"--eval-upsample",
type=float,
default=1.0,
help="upsample factor",
"--eval-upsample", type=float, default=1.0, help="upsample factor",
)
return parser

Expand Down
4 changes: 1 addition & 3 deletions slue_toolkit/fairseq_addon/data/add_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@

class AddLabelDataset(BaseWrapperDataset):
def __init__(
self,
dataset,
labels,
self, dataset, labels,
):
super().__init__(dataset)
self.labels = labels
Expand Down
8 changes: 2 additions & 6 deletions slue_toolkit/fairseq_addon/tasks/audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class AudioClassificationTask(AudioPretrainingTask):
cfg: AudioClassificationConfig

def __init__(
self,
cfg: AudioClassificationConfig,
self, cfg: AudioClassificationConfig,
):
super().__init__(cfg)
self.blank_symbol = "<s>"
Expand Down Expand Up @@ -76,10 +75,7 @@ def load_dataset(
f"({len(self.datasets[split])}) do not match"
)

self.datasets[split] = AddLabelDataset(
self.datasets[split],
labels,
)
self.datasets[split] = AddLabelDataset(self.datasets[split], labels,)

@property
def label2id(self):
Expand Down
53 changes: 53 additions & 0 deletions slue_toolkit/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,59 @@

import pickle as pkl

# Single-character marker assigned to each raw (fine-grained) entity label.
# Every label maps to a distinct character, so the table can be inverted.
raw_entity_to_spl_char = {
    "CARDINAL": "!",
    "DATE": "@",
    "EVENT": "#",
    "FAC": "$",
    "GPE": "%",
    "LANGUAGE": "^",
    "LAW": "&",
    "LOC": "*",
    "MONEY": "(",
    "NORP": ")",
    "ORDINAL": "~",
    "ORG": "`",
    "PERCENT": "{",
    "PERSON": "}",
    "PRODUCT": "[",
    "QUANTITY": "<",
    "TIME": ">",
    "WORK_OF_ART": "?",
}

# Reverse lookup for the raw table only: marker character -> raw entity label.
spl_char_to_entity = {
    marker: entity for entity, marker in raw_entity_to_spl_char.items()
}

# Collapses each raw entity label into its combined (coarse) label.
# Labels mapped to "DISCARD" have no entry in combined_entity_to_spl_char.
raw_to_combined_tag_map = {
    # temporal expressions
    "DATE": "WHEN",
    "TIME": "WHEN",
    # numeric expressions
    "CARDINAL": "QUANT",
    "ORDINAL": "QUANT",
    "QUANTITY": "QUANT",
    "MONEY": "QUANT",
    "PERCENT": "QUANT",
    # locations
    "GPE": "PLACE",
    "LOC": "PLACE",
    # labels kept under their own name
    "NORP": "NORP",
    "ORG": "ORG",
    "LAW": "LAW",
    "PERSON": "PERSON",
    # labels without a combined counterpart
    "FAC": "DISCARD",
    "EVENT": "DISCARD",
    "WORK_OF_ART": "DISCARD",
    "PRODUCT": "DISCARD",
    "LANGUAGE": "DISCARD",
}

# Single-character marker assigned to each combined label. The characters
# overlap with the raw table ("!" is CARDINAL there but LAW here), so a marker
# is only meaningful together with the label set it was written for.
combined_entity_to_spl_char = {
    "LAW": "!",
    "NORP": "@",
    "ORG": "#",
    "PERSON": "$",
    "PLACE": "%",
    "QUANT": "^",
    "WHEN": "&",
}


def save_pkl(fname, dict_name):
with open(fname, "wb") as f:
Expand Down
Binary file not shown.
47 changes: 0 additions & 47 deletions slue_toolkit/label_map_files/ner.dict.ltr.txt

This file was deleted.

Binary file removed slue_toolkit/label_map_files/raw_entity_to_spl_char.pkl
Binary file not shown.
Binary file removed slue_toolkit/label_map_files/raw_to_combined_tags.pkl
Binary file not shown.
Binary file removed slue_toolkit/label_map_files/spl_char_to_entity.pkl
Binary file not shown.
4 changes: 2 additions & 2 deletions slue_toolkit/prepare/create_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import fire


def main(input, output, thres=0):
def create_dict(input, output, thres=0):
counter = Counter()
with open(input) as f:
for line in f:
Expand All @@ -15,4 +15,4 @@ def main(input, output, thres=0):


if __name__ == "__main__":
fire.Fire(main)
fire.Fire(create_dict)
3 changes: 2 additions & 1 deletion slue_toolkit/prepare/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle as pkl
from slue_toolkit.generic_utils import raw_to_combined_tag_map


def load_pkl(fname, encdng=None):
Expand Down Expand Up @@ -28,7 +29,7 @@ def get_label_lst(label_str, label_type):
"""
if label_str == "None" or label_str == "[]":
return []
tag_map = load_pkl("slue_toolkit/label_map_files/raw_to_combined_tags.pkl")
tag_map = raw_to_combined_tag_map
label_lst = []
ner_labels_lst = label_str.strip("[[").strip("]]").split("], [")
for item in ner_labels_lst:
Expand Down
2 changes: 1 addition & 1 deletion slue_toolkit/prepare/prepare_voxceleb.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def create_manifest(
frames = int(16000 * (end - start))
print(f"{id}.flac\t{frames}", file=f)

if not (split == "test") and is_blind:
if split != "test" or not is_blind:
with open(os.path.join(manifest_dir, f"{split}.wrd"), "w") as f:
for text in df["normalized_text"].array:
print(text, file=f)
Expand Down
26 changes: 22 additions & 4 deletions slue_toolkit/prepare/prepare_voxpopuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re

from slue_toolkit.prepare import data_utils
from slue_toolkit.prepare.create_dict import create_dict

splits = {"fine-tune", "dev", "test"}

Expand Down Expand Up @@ -53,7 +54,7 @@ def create_manifest(
).frames
print(f"{uid}.ogg\t{frames}", file=f)

if not (split == "test") and is_blind:
if split != "test" or not is_blind:
with open(os.path.join(manifest_dir, f"{split}.wrd"), "w") as f:
for text in df["normalized_text"].array:
text = re.sub(r"[\.;?!]", "", text)
Expand All @@ -71,13 +72,13 @@ def create_manifest(
os.makedirs(os.path.join(manifest_dir, sub_dir_name), exist_ok=True)
for label_type in ["raw", "combined"]:
wrd_fn = os.path.join(
manifest_dir, "e2e_ner", f"{split}_{label_type}.wrd"
manifest_dir, "e2e_ner", f"{split}.{label_type}.wrd"
)
ltr_fn = os.path.join(
manifest_dir, "e2e_ner", f"{split}_{label_type}.ltr"
manifest_dir, "e2e_ner", f"{split}.{label_type}.ltr"
)
tsv_fn = os.path.join(
manifest_dir, "nlp_ner", f"{split}_{label_type}.tsv"
manifest_dir, "nlp_ner", f"{split}.{label_type}.tsv"
)
with open(wrd_fn, "w") as f_wrd, open(ltr_fn, "w") as f_ltr, open(
tsv_fn, "w"
Expand All @@ -94,6 +95,23 @@ def create_manifest(
)
print(wrd_str, file=f_wrd)
print(ltr_str, file=f_ltr)
try:
os.symlink(
f"../{split}.tsv", os.path.join(manifest_dir, f"e2e_ner/{split}.tsv")
)
except:
pass

for label_type in ["raw", "combined"]:
for token_type in ["wrd", "ltr"]:
create_dict(
os.path.join(
manifest_dir, f"e2e_ner/fine-tune.{label_type}.{token_type}"
),
os.path.join(
manifest_dir, f"e2e_ner/dict.{label_type}.{token_type}.txt"
),
)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion slue_toolkit/text_ner/ner_deberta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os, fire
import os
import fire

import slue_toolkit.text_ner.ner_deberta_modules as NDM
from slue_toolkit.generic_utils import read_lst, load_pkl, save_pkl
Expand Down
Loading

0 comments on commit 78d06f3

Please sign in to comment.