Skip to content

Commit

Permalink
Merge pull request #57 from Hoorayeah/main
Browse files Browse the repository at this point in the history
update source code of spokenwoz
  • Loading branch information
tnlin committed Jan 8, 2024
2 parents dbe8725 + 17c546d commit fd015e3
Show file tree
Hide file tree
Showing 54 changed files with 4,872 additions and 73,318 deletions.
2 changes: 1 addition & 1 deletion spokenwoz/Finetuning/space_baseline/data_process.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

cd ../ubar
sh data_process.sh

cd data
mkdir ../../space_baseline/text_data/
cp -r multi-woz/* ../../space_baseline/text_data/
cp -r multi-woz-analysis/* ../../space_baseline/text_data/
cp -r multi-woz-processed/* ../../space_baseline/text_data/
Expand Down
9 changes: 8 additions & 1 deletion spokenwoz/Finetuning/space_baseline/move_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
cd ./text_data
mv data_for_damd.json data_for_space.json
cd ../
mkdir ./space_concat/space/data/multiwoz2.0/
mkdir ./space_word/space/data/multiwoz2.0/
mkdir ./space-3/space/data/multiwoz2.0/
cp -r ./text_data/* ./space_concat/space/data/multiwoz2.0/
cp -r ./text_data/* ./space_word/space/data/multiwoz2.0/
cp -r ./text_data/* ./space-3/space/data/multiwoz2.0/

mkdir ./space_concat/db/
mkdir ./space_word/db/
mkdir ./space-3/db/

cp -r ./db/* ./space_concat/db/
cp -r ./db/* ./space_word/db/
cp -r ./db/* ./space_3/db/
cp -r ./db/* ./space-3/db/
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,10 @@
from space.metrics.metrics_tracker import MetricsTracker
from transformers import Wav2Vec2Processor




def compute_jacc(data,default_cleaning_flag=True,type2_cleaning_flag=False):
num_turns = 0
joint_acc = 0
joint_acc_wo_cross = 0
joint_acc_wo_wrong = 0
joint_acc_wo_namestr = 0

error = {}
clean_tokens = ['<|endoftext|>', ]
dict_slot_acc_right = {}
Expand Down Expand Up @@ -90,64 +85,18 @@ def compute_jacc(data,default_cleaning_flag=True,type2_cleaning_flag=False):
for domain_slot in dict_slot_acc_right.keys():
dict_rate[domain_slot] = dict_slot_acc_right[domain_slot] / dict_slot_acc_all[domain_slot]

join_flag = False

turn_pred_wo_namestr = []
turn_target_wo_namestr = []
for item in turn_pred:
if 'namestr' not in item:
turn_pred_wo_namestr.append(item)
else:
pass
for item in turn_target:
if 'namestr' not in item:
turn_target_wo_namestr.append(item)
else:
pass

if set(turn_target_wo_namestr) == set(turn_pred_wo_namestr):
joint_acc_wo_namestr += 1
join_flag = True
elif type2_cleaning_flag: # check for possible Type 2 noisy annotations
flag = True
for bs in turn_target_wo_namestr:
if bs not in turn_pred_wo_namestr:
flag = False
break
if flag:
for bs in turn_pred_wo_namestr:
if bs not in turn_target_wo_namestr:
flag = False
break
if flag: # model prediction might be correct if found in Type 2 list of noisy annotations
dial_name = dial.split('.')[0]
if dial_name in IGNORE_TURNS_TYPE2 and turn_id in IGNORE_TURNS_TYPE2[dial_name]: # ignore these turns
pass
else:
joint_acc_wo_namestr += 1


join_flag = False



#filter out inexplicable outputs
turn_pred_wo_wrong = []
turn_target_wo_wrong = []
for item in turn_pred:
if 'emma' not in item and 'jerry' not in item and 'namestr' not in item:
turn_pred_wo_wrong.append(item)
else:
pass
for item in turn_target:
if 'emma' not in item and 'jerry' not in item and 'namestr' not in item:
turn_target_wo_wrong .append(item)
else:
pass

for instance_pred, instance_target in zip(turn_pred,turn_target):
if 'emma' not in instance_pred and 'jerry' not in instance_pred and 'john' not in instance_pred and 'micheal' not in instance_pred and 'emma' not in instance_target and 'jerry' not in instance_target and 'john' not in instance_target and 'micheal' not in instance_target:
turn_pred_wo_wrong.append(instance_pred)
turn_target_wo_wrong.append(instance_target)

if set(turn_target_wo_wrong) == set(turn_pred_wo_wrong):
joint_acc_wo_wrong += 1
joint_acc += 1
join_flag = True
elif type2_cleaning_flag: # check for possible Type 2 noisy annotations
flag = True
Expand All @@ -165,80 +114,9 @@ def compute_jacc(data,default_cleaning_flag=True,type2_cleaning_flag=False):
if dial_name in IGNORE_TURNS_TYPE2 and turn_id in IGNORE_TURNS_TYPE2[dial_name]: # ignore these turns
pass
else:
joint_acc_wo_wrong += 1




join_flag = False

turn_pred_wo_cross = []
turn_target_wo_cross = []
for item in turn_pred:
if '[profile]' not in item:
turn_pred_wo_cross.append(item)
else:
pass
for item in turn_target:
if '[profile]' not in item:
turn_target_wo_cross.append(item)
else:
pass

if set(turn_target_wo_cross) == set(turn_pred_wo_cross):
joint_acc_wo_cross += 1
join_flag = True
elif type2_cleaning_flag: # check for possible Type 2 noisy annotations
flag = True
for bs in turn_target_wo_cross:
if bs not in turn_pred_wo_cross:
flag = False
break
if flag:
for bs in turn_pred_wo_cross:
if bs not in turn_target_wo_cross:
flag = False
break
if flag: # model prediction might be correct if found in Type 2 list of noisy annotations
dial_name = dial.split('.')[0]
if dial_name in IGNORE_TURNS_TYPE2 and turn_id in IGNORE_TURNS_TYPE2[dial_name]: # ignore these turns
pass
else:
joint_acc_wo_cross += 1


joint_acc += 1


join_flag = False
# print('turn_pred ',turn_pred)
# print('turn_target',turn_target)
# print('turn_pred_wo_cross',set(turn_pred_wo_cross))
# print('turn_target_wo_cross',set(turn_target_wo_cross))
# print('turn_pred ',set(turn_pred))
# print('turn_target',set(turn_target))
if set(turn_target) == set(turn_pred):
joint_acc += 1
join_flag = True

elif type2_cleaning_flag: # check for possible Type 2 noisy annotations
flag = True
for bs in turn_target:
if bs not in turn_pred:
flag = False
break
if flag:
for bs in turn_pred:
if bs not in turn_target:
flag = False
break

if flag: # model prediction might be correct if found in Type 2 list of noisy annotations
dial_name = dial.split('.')[0]
if dial_name in IGNORE_TURNS_TYPE2 and turn_id in IGNORE_TURNS_TYPE2[dial_name]: # ignore these turns
pass
else:
joint_acc += 1
join_flag = True
if not join_flag:
if file_name not in error:
error[file_name] = {}
Expand All @@ -248,18 +126,14 @@ def compute_jacc(data,default_cleaning_flag=True,type2_cleaning_flag=False):

num_turns += 1


joint_acc /= num_turns
joint_acc_wo_cross /= num_turns
joint_acc_wo_namestr /= num_turns
joint_acc_wo_wrong /= num_turns
print('joint accuracy: {}'.format(joint_acc))
print('joint accuracy_wo_cross: {}'.format(joint_acc_wo_cross))
print('joint accuracy_wo_namestr: {}'.format(joint_acc_wo_cross))
print('joint_acc_wo_wrong: {}'.format(joint_acc_wo_wrong))
print('dict_rate: {}'.format(dict_rate))
with open('bs_error.json',"w") as f:
json.dump(error,f,indent=2)
return joint_acc, joint_acc_wo_cross, dict_rate
return joint_acc, dict_rate


def get_logger(log_path, name="default"):
logger = logging.getLogger(name)
Expand Down Expand Up @@ -645,7 +519,6 @@ def train_epoch(self, train_data, dev_data):
usr_audio.append(np.load(audio_dir))

transcripts.append(self.transcripts[dial_id]['log'][int(turn_id)*2])
#usr_audio.append(np.load(f'/mnt/mingz/mingz/spokenwoz/data/audio/{dial_id}{turn_id}-usr.npy'))
usr_audio = self.processor(usr_audio, sampling_rate=16000, padding=True, return_attention_mask=True,
return_tensors="pt")
# print(batch)
Expand Down Expand Up @@ -767,7 +640,7 @@ def infer(self, data_type='test'):
usr_audio = []
dial_id, turn_id = turn['dial_id'], turn['turn_num']
dial_id = dial_id[:3].upper() + dial_id[3:]
usr_audio.append(np.load(f'/data/nt12_hdd_gluster/myself/space3/space/data/audio/{dial_id}{turn_id}-usr.npy'))
usr_audio.append(np.load(f'../../audio/audio_5700_train_dev_npy/{dial_id}{turn_id}-usr.npy'))
transcripts = [self.transcripts[dial_id]['log'][int(turn_id)*2]]
usr_audio = self.processor(usr_audio, sampling_rate=16000, padding=True, return_attention_mask=True,return_tensors="pt")
outputs = self.func_model.infer(inputs=batch, start_id=prompt_id,audios=usr_audio,transcripts=transcripts,tokenizer=self.tokenizer,eos_id=self.reader.eos_r_id, max_gen_len=max_len)
Expand Down Expand Up @@ -923,8 +796,7 @@ def infer_dst(self, data_type='test'):
usr_audio = []
dial_id, turn_id = turn['dial_id'], turn['turn_num']
dial_id = dial_id[:3].upper() + dial_id[3:]
usr_audio.append(np.load(f'/data/nt12_hdd_gluster/myself/space3/space/data/audio/{dial_id}{turn_id}-usr.npy'))
#usr_audio.append(np.load(f'/mnt/mingz/mingz/spokenwoz/data/audio/{dial_id}{turn_id}-usr.npy'))
usr_audio.append(np.load(f'../../audio/audio_5700_train_dev_npy/{dial_id}{turn_id}-usr.npy'))
transcripts = [self.transcripts[dial_id]['log'][int(turn_id)*2]]
usr_audio = self.processor(usr_audio, sampling_rate=16000, padding=True, return_attention_mask=True,
return_tensors="pt")
Expand Down Expand Up @@ -1036,24 +908,22 @@ def infer_dst(self, data_type='test'):
results, _ = self.reader.wrap_result_lm(tmp_dialog_result)

#compute JGA for dst
joint_acc, joint_acc_wo_cross, dict_rate = compute_jacc(data=model_output)
joint_acc, dict_rate = compute_jacc(data=model_output)

# compute scores

# log results
metrics_message = 'joint_acc,: %2.2f' %\
(joint_acc)
metrics_message_plus = 'joint_acc_wo_cross,: %2.2f' %\
(joint_acc_wo_cross)

message_prefix = f"[Infer][{self.epoch}]"
time_cost = f"TIME-{time.time() - begin_time:.3f}"
message = " ".join([message_prefix, metrics_message, metrics_message_plus, time_cost])
message = " ".join([message_prefix, metrics_message, time_cost])
self.logger.info(message)

# save results
eval_results = {
'joint_acc': joint_acc,
'joint_acc_wo_cross': joint_acc_wo_cross,
'result': message
}
with open(infer_save_file, "w") as fp:
Expand Down
4 changes: 2 additions & 2 deletions spokenwoz/Finetuning/space_baseline/space_word/run_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def set_seed(seed):
if hparams.do_infer:
# infer process
if hparams.do_dst == True:
trainer.infer_dst(data_type='test')
trainer.infer_dst(data_type='dev')
else:
trainer.infer(data_type='test')
trainer.infer(data_type='dev')


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ python -u run_gen.py \
--max_len=1024 \
--max_ctx_turn=100 \
--token_loss=true \
--data_processed=test_for_space_encoded.data.json
--data_processed=valid_for_space_encoded.data.json
# --data_processed=data_for_spectra_encoded.data.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ python -u run_gen.py \
--max_len=1024 \
--max_ctx_turn=20 \
--token_loss=true \
--data_processed=data_for_space_encoded.data.json
--data_processed=valid_for_space_encoded.data.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ python -u run_gen.py \
--max_len=1024 \
--max_ctx_turn=20 \
--token_loss=true \
--data_processed=data_for_space_encoded.data.json
--data_processed=valid_for_space_encoded.data.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ SEED=10
SAVE_DIR=${SAVE_ROOT}/outputs/${DATA_NAME}${VERSION}/space-103

# Main run.
python -u run_gen.py \
python3 -u run_gen.py \
--do_train=true \
--model=${MODEL} \
--policy=${POLICY} \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,13 @@ def __init__(self, hparams):
self.understand_ids = self.tokenizer.convert_tokens_to_ids(self.understand_tokens)
self.policy_ids = self.tokenizer.convert_tokens_to_ids(self.policy_tokens)

test_list = [l.strip().lower() for l in open(
os.path.join(self.data_root, f'data/multiwoz{self.version}/testListFile.json'), 'r').readlines()]
# test_list = [l.strip().lower() for l in open(
# os.path.join(self.data_root, f'data/multiwoz{self.version}/testListFile.json'), 'r').readlines()]
dev_list = [l.strip().lower() for l in open(
os.path.join(self.data_root, f'data/multiwoz{self.version}/valListFile.json'), 'r').readlines()]
self.dev_files, self.test_files = {}, {}
for fn in test_list:
self.test_files[fn.replace('.json', '')] = 1
# for fn in test_list:
# self.test_files[fn.replace('.json', '')] = 1
for fn in dev_list:
self.dev_files[fn.replace('.json', '')] = 1

Expand Down Expand Up @@ -846,6 +846,7 @@ def convert_turn_eval(self, turn, pv_turn, first_turn=False):
if self.use_true_curr_bspn:
pv_context = pv_turn['labels'] + [pv_turn['aspn'] + pv_turn['resp']]
else:
# print(pv_turn)
pv_context = pv_turn['labels'] + [pv_turn['bspn'] + pv_turn['db'] + pv_turn['aspn'] + pv_turn['resp']]

# prompt response, add sos_r
Expand Down Expand Up @@ -1850,13 +1851,13 @@ def __init__(self, hparams):
self.understand_ids = self.tokenizer.convert_tokens_to_ids(self.understand_tokens)
self.policy_ids = self.tokenizer.convert_tokens_to_ids(self.policy_tokens)

test_list = [l.strip().lower() for l in open(
os.path.join(self.data_root, f'data/multiwoz{self.version}/testListFile.json'), 'r').readlines()]
# test_list = [l.strip().lower() for l in open(
# os.path.join(self.data_root, f'data/multiwoz{self.version}/testListFile.json'), 'r').readlines()]
dev_list = [l.strip().lower() for l in open(
os.path.join(self.data_root, f'data/multiwoz{self.version}/valListFile.json'), 'r').readlines()]
self.dev_files, self.test_files = {}, {}
for fn in test_list:
self.test_files[fn.replace('.json', '')] = 1
# for fn in test_list:
# self.test_files[fn.replace('.json', '')] = 1
for fn in dev_list:
self.dev_files[fn.replace('.json', '')] = 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(self, hparams, reader, generator, dtype="float32"):
self.with_mlm = hparams.with_mlm
self._dtype = dtype
self.audio_encoder = WavLMForMAM.from_pretrained('microsoft/wavlm-base-plus')

self.requires_grad_(requires_grad=False)
# self.lstm = nn.LSTM(768, 768, 1, bidirectional = True, batch_first = True)
# self.fused_encoder = TransformerBlock(self.hidden_dim,
Expand Down
Loading

0 comments on commit fd015e3

Please sign in to comment.