In [None]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
import os,json
from dataclasses import dataclass
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    HfArgumentParser,
    TrainingArguments,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    pipeline,
)
import torch
from torch.utils.data.dataset import Dataset
import datasets
from datasets import load_dataset, load_metric
# from sysproxy import SysProxy
# sys_proxy = SysProxy()

In [None]:
def load_model(modelpath,device='cuda'):
    """Load a causal language model and the tokenizer stored alongside it.

    Args:
        modelpath: Value handed to ``from_pretrained`` for both the model and
            the tokenizer.
        device: Device the model is moved to after loading (default ``'cuda'``).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    lm = AutoModelForCausalLM.from_pretrained(modelpath)
    lm = lm.to(device)
    tok = AutoTokenizer.from_pretrained(modelpath)
    return lm, tok

In [None]:
# Test-set prompts (source) and their references (target); line i of one file
# corresponds to line i of the other.
#
# To evaluate a single task instead, point the two paths at the per-task files:
# taskname = "NLG"
# taskname = "NLU"
# source_fname = f'/home/jitianbo/Workspace/driver_simulator_kvret/data/data_for_clm/{taskname}/test-{taskname}.source'
# target_fname = f'/home/jitianbo/Workspace/driver_simulator_kvret/data/data_for_clm/{taskname}/test-{taskname}.target'

source_fname = '/home/jitianbo/Workspace/driver_simulator_kvret/data/data_for_clm/test.source'
target_fname = '/home/jitianbo/Workspace/driver_simulator_kvret/data/data_for_clm/test.target'
with open(source_fname) as f:
    source_data = f.read().strip().splitlines()
with open(target_fname) as f:
    target_data = f.read().strip().splitlines()

# The inference loop looks up target_data[i] for every source line, so a length
# mismatch would only surface as an IndexError part-way through generation (or
# silently leave targets unused). Fail fast before any expensive work starts.
if len(source_data) != len(target_data):
    raise ValueError(
        f"source/target line counts differ: {len(source_data)} lines in {source_fname} "
        f"vs {len(target_data)} lines in {target_fname}"
    )

In [None]:
# Stray IPython help lookup (`Path.joinpath?`) removed: it is interactive-only
# syntax rather than valid Python, and nothing in the notebook depends on it.

In [None]:
# modeldir = '/home/jitianbo/Workspace/driver_simulator_kvret/simulator/clm-output/'.rstrip(r'/')
# modelname = 'gpt2'

# modelpath = f"{modeldir}/{modelname}/"

# pipeline_task = "text-generation"
# device = 'cuda'
# p = pipeline(
#     task=pipeline_task,
#     model=model,
#     tokenizer=tokenizer,
#     device=0,
#     batch_size=8,
#     max_length=512,
# )

In [None]:
# Dataset split label; used as the filename prefix of the result files written
# at the end of the notebook.
dset = 'test'

In [None]:
# The checkpoint is read from <modeldir>/<modelname>/.
modeldir = '/home/jitianbo/Workspace/driver_simulator_kvret/simulator/clm-output/'.rstrip(r'/')
modelname = 'distilgpt2'
modelpath = f"{modeldir}/{modelname}/"

# Split label used as the prefix of the saved result filenames.
dset = 'test'
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing when the model is moved to the device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, tokenizer = load_model(modelpath,device=device)


In [None]:
def detect_task_by_source(source_text):
    """Infer the sub-task of a source line from the marker it ends with.

    * ends with ``[eoau]`` -> ``"NLU"`` (assistant utterance -> assistant action)
    * ends with ``[eoaa]`` -> ``"POL"`` (assistant action -> driver action)
    * anything else        -> ``"NLG"`` (driver action -> driver utterance)

    Args:
        source_text: The prompt string fed to the model.

    Returns:
        The task name as a string.
    """
    suffix_to_task = (
        ('[eoau]', "NLU"),
        ('[eoaa]', "POL"),
    )
    for suffix, task_name in suffix_to_task:
        if source_text.endswith(suffix):
            return task_name
    # No recognised closing marker: treated as the NLG task.
    return "NLG"
def get_eos_by_task(task,tokenizer):
    """Return the id of the end marker associated with ``task``.

    Args:
        task: ``"NLU"``, ``"POL"`` or ``"NLG"``; any other value raises ``KeyError``.
        tokenizer: Object whose ``encode(text)`` returns an indexable sequence of ids.

    Returns:
        The first id produced by encoding the task's end marker.
        NOTE(review): this presumably relies on every marker being a single
        token in the tokenizer's vocabulary -- confirm against the tokenizer used.
    """
    end_marker = {
        "NLU": "[eoaa]",
        "POL": "[eoda]",
        "NLG": "[eodu]",
    }[task]
    token_ids = tokenizer.encode(end_marker)
    return token_ids[0]

def get_sos_by_task(task,tokenizer):
    """Return the id of the start marker associated with ``task``.

    Args:
        task: ``"NLU"``, ``"POL"`` or ``"NLG"``; any other value raises ``KeyError``.
        tokenizer: Object whose ``encode(text)`` returns an indexable sequence of ids.

    Returns:
        The first id produced by encoding the task's start marker.
    """
    start_marker = {
        "NLU": "[soaa]",
        "POL": "[soda]",
        "NLG": "[sodu]",
    }[task]
    token_ids = tokenizer.encode(start_marker)
    return token_ids[0]

# def decode_by_task(generated_ori,tokenizer,task):
#     generated = generated_ori.cpu().numpy()
#     sos_id = get_sos_by_task(task,tokenizer)
#     eos_id = get_eos_by_task(task,tokenizer)
#     if sos_id not in generated or eos_id not in generated:
#         return generated
    
#     sos_idx = generated.index(sos_id)
#     eos_idx = generated.index(eos_id)
#     if sos_idx < eos_idx:
#         to_decode = generated[:, sos_idx:eos_idx+1]
#     else:
#         sos_idxes = np.where(arr == 15)
#     return to_decode

def process_generated_text(result,task):
    """Cut the task's ``[so..] ... [eo..]`` span out of a generated string.

    The text is split on whitespace. The span starts at the first occurrence of
    the task's start marker and ends at the first end marker found after it;
    both markers are included and the words are re-joined with single spaces.

    Args:
        result: Decoded model output.
        task: ``"NLU"``, ``"POL"`` or ``"NLG"``; any other value raises ``KeyError``.

    Returns:
        The extracted span, or ``result`` unchanged when no valid span exists
        (a marker is missing, or every end marker precedes the first start marker).
    """
    pieces = result.split()
    end_markers = {
        "NLU": "[eoaa]",
        "POL": "[eoda]",
        "NLG": "[eodu]",
    }
    start_markers = {
        "NLU": "[soaa]",
        "POL": "[soda]",
        "NLG": "[sodu]",
    }
    opening = start_markers[task]
    closing = end_markers[task]
    try:
        first = pieces.index(opening)
        # Search only to the right of the start marker, so end markers that
        # appear before it are skipped.
        last = pieces.index(closing, first + 1)
    except ValueError:
        # No start marker, or no end marker after it: nothing valid to extract,
        # so hand the text back as it came in.
        return result
    return ' '.join(pieces[first:last + 1])

def remove_special_sep_tokens(text):
    """Strip the start/end separator markers from ``text``.

    Adjacent markers written as ``][`` are first separated by a space, then the
    text is split on whitespace. A word is dropped when it is one of the known
    ``[so..]`` / ``[eo..]`` markers, or when it starts with ``[`` but does not
    end with ``]`` (e.g. a cut-off marker such as ``[eoa``). The remaining words
    are joined with single spaces.

    Args:
        text: Raw or already-trimmed text that may contain separator markers.

    Returns:
        The cleaned text.
    """
    separators = frozenset((
        '[eoaa]', '[eoau]', '[eoda]', '[eodp]', '[eodu]',
        '[soaa]', '[soau]', '[soda]', '[sodp]', '[sodu]',
    ))

    def _is_kept(word):
        if word in separators:
            return False
        # Keep everything except an opening bracket that is never closed.
        return word.endswith("]") or not word.startswith("[")

    spaced = text.replace("][", "] [")
    return ' '.join(filter(_is_kept, spaced.split()))

def process_raw_result(raw_result,task):
    """Post-process one decoded generation for ``task``.

    First trims the text with ``process_generated_text``, then removes the
    separator markers with ``remove_special_sep_tokens``.

    Args:
        raw_result: Decoded model output.
        task: Task name forwarded to ``process_generated_text``.

    Returns:
        The cleaned text.
    """
    span = process_generated_text(raw_result,task)
    return remove_special_sep_tokens(span)

In [None]:
# Generate a continuation for every test prompt and bucket the outputs by task.
#   raw_result_dict: decoded continuation exactly as generated (whitespace-stripped)
#   result_dict:     continuation after process_raw_result post-processing
#   target_dict:     reference line with the separator markers removed
raw_result_dict = {task_name: [] for task_name in ("NLU", "NLG", "POL")}
result_dict = {task_name: [] for task_name in ("NLU", "NLG", "POL")}
target_dict = {task_name: [] for task_name in ("NLU", "NLG", "POL")}

# Output location: ./inference-results/<modelname>/ (created if missing).
save_dir = Path('./inference-results').absolute().resolve()
save_dir.mkdir(parents=True, exist_ok=True)
model_dir = save_dir.joinpath(f"{modelname}")
model_dir.mkdir(parents=True, exist_ok=True)

# Upper bound on the number of tokens generated beyond the prompt.
max_len = 80

for i,source_text in enumerate(tqdm(source_data)):
    input_ids = tokenizer(source_text, return_tensors="pt").to(device).input_ids
    input_len = input_ids.shape[-1]

    # The task decides which marker ends generation.
    task = detect_task_by_source(source_text)
    eos_token_id = get_eos_by_task(task,tokenizer)

    # Padding id preference: the tokenizer's own pad token, then its eos token,
    # and only as a last resort the task-specific end marker.
    if tokenizer.pad_token_id is not None:
        pad_token_id = tokenizer.pad_token_id
    elif tokenizer.eos_token_id is not None:
        pad_token_id = tokenizer.eos_token_id
    else:
        pad_token_id = eos_token_id

    # NOTE(review): `temperature` only influences sampling-based decoding and no
    # `do_sample=True` is passed here, so unless the model's generation defaults
    # enable sampling this argument has no effect -- confirm the intent.
    outputs = model.generate(
        input_ids,
        max_length=input_len+max_len,
        temperature=0.7,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
    )

    # Decode only the newly generated part; the prompt occupies the first
    # input_len positions of the output.
    generated = outputs[:,input_len:]
    results = tokenizer.batch_decode(generated, skip_special_tokens=False,clean_up_tokenization_spaces=True)
    raw_result = results[0].strip()
    raw_result_dict[task].append(raw_result)

    result = process_raw_result(raw_result,task)
    result_dict[task].append(result)

    # References get the same marker removal so they are comparable to `result`.
    raw_target = target_data[i]
    target = remove_special_sep_tokens(raw_target)
    target_dict[task].append(target)
    
    

In [None]:
# Display the resolved directory under which the per-model results are stored.
save_dir

In [None]:
# Write one file per task and output kind into model_dir, named
# "<dset>-<task>.<kind>", with one entry per line.
outputs_by_suffix = (
    ("raw", raw_result_dict),     # generations as decoded
    ("result", result_dict),      # post-processed generations
    ("target", target_dict),      # cleaned references
)
for t in ["NLU","POL","NLG"]:
    for suffix, per_task in outputs_by_suffix:
        savefpath = model_dir.joinpath(f"{dset}-{t}.{suffix}")
        with savefpath.open('w') as f:
            for entry in per_task[t]:
                f.write(f"{entry}\n")
    