In [1]:
import os
import json
import pandas as pd

from tqdm import tqdm
from langchain import PromptTemplate

from src.DST.dst import SLOTS_DESCRIPTIONS
from src.DST.config import CONFIG

from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM, AutoModelForCausalLM, T5Tokenizer, T5ForConditionalGeneration, pipeline

class PromptConstructor():
    def __init__(self, 
                 config):
        self.config = config
        self.instructions = config["INSTRUCTIONS"]
        self.prompt_templates = config["PROMPT_TEMPLATES"]
        
    def _get_slots_from_domains(self, domains, with_slot_description, with_req_inf_differentiation, with_all_slots):
        # slot_description = self.config["slot_descrpition"]
        if with_all_slots:
            domains = "all"
        
        if with_slot_description:
            with_req_inf_differentiation = False #Slot description is the discriminator

        if domains == "all":
            if with_req_inf_differentiation:
                req_slots = ", ".join(self.config["multiwoz21"]["all_requestable_slots"])
                inf_slots = ", ".join(self.config["multiwoz21"]["all_informable_slots"])
            else:
                slots = set(self.config["multiwoz21"]["all_requestable_slots"] + 
                            self.config["multiwoz21"]["all_informable_slots"])
                slots = ", ".join(slots)
        elif not isinstance(domains, list):
            raise ValueError("""Provided domain should be either 'all' or list of valid domain names:
                                - for multiwoz2.1 and 2.4: taxi, restaurant, hotel, train, attraction 
                                - for SGD: To-do""")
        else:
            req_slots = ""
            inf_slots = ""
            domain_req_slots = []
            domain_inf_slots = []
            for domain in domains:
                domain_req_slots += self.config["multiwoz21"]["requestable_slots"][domain]
                domain_inf_slots += self.config["multiwoz21"]["informable_slots"][domain]
            if with_req_inf_differentiation:
                domain_req_slots = set(domain_req_slots)
                domain_inf_slots = set(domain_inf_slots)
                req_slots += ", ".join(domain_req_slots)
                inf_slots += ", ".join(domain_inf_slots)
            else:
                slots = set(domain_req_slots + domain_inf_slots)
                slots = ", ".join(slots)

        if with_req_inf_differentiation:
            slots_info = f"Requestable slots: {req_slots}\nInformable slots: {inf_slots}"
        else:
            slots_info = f"{slots}"

        if with_slot_description:
            slots = slots.split(", ")
            slots_info = ""
            for slot in slots:
                if slot not in self.config["multiwoz21"]["all_informable_slots"]:
                    continue
                slots_info += f"name: {slot}, description: {SLOTS_DESCRIPTIONS[slot]}\n"
            slots_info = slots_info[:-2]
        
        return slots_info
    
    
    def _build_prompt(self, mode="", dialogue_context="", ontology="", slots="", dialogue_acts="", belief_states=""):
        prompt = ""
        if mode == "dst":
            instruction = self.instructions["instruction_with_slots"]
            template_variables = self.prompt_templates["template_with_slots"]
            template = PromptTemplate(input_variables= template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                     slots=slots,
                                     dialogue_context=dialogue_context)
            
        elif mode == "dst_recorrect":
            instruction = self.instructions["instruction_with_slots_recorrect"]
            template = self.prompt_templates["template_with_slots_recorrect"]
            template = PromptTemplate(input_variables= template_variables["input_variables"],
                                      template = template_variables["template"])            
            prompt = template.format(instruction=instruction,
                                    slots=slots,
                                    dialogue_context=dialogue_context,
                                    belief_states=belief_states)
            
        elif mode == "database_query":
            instruction = self.instructions["instruction_query_database"]
            template = self.prompt_templates["template_query_database"]
            template = PromptTemplate(input_variables= template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                    belief_states=belief_states)
            
        elif mode == "response_generation":
            instruction = self.instructions["instruction_response_generation"]
            template = self.prompt_templates["template_response_generation"]
            template = PromptTemplate(input_variables = template_variables["input_variables"],
                                      template = template_variables["template"])
            prompt = template.format(instruction=instruction,
                                    dialogue_acts=dialogue_acts,
                                    dialogue_context=dialogue_context)
        elif mode == "dst_extracted_ontology":
            pass

        else:
            raise ValueError("'mode' should be one of: [dst, dst_recorrect, database_query, response_generation]")
        
        return prompt


class MWOZ_Dataset(PromptConstructor):
    def __init__(self,
                 config,
                 mwoz_path,
                 dialog_history_limit,
                 with_slot_description,
                 with_req_inf_differentiation,
                 single_domain_only,
                 with_all_slots):
        PromptConstructor.__init__(self, config)
        self.dataset = {"id":[],
                        "dialogue_id":[],
                        "dialogue_context":[],
                        "turn":[],
                        "prompt":[],
                        "domains":[],
                        "gold_bs":[],
                        "gold_act":[],
                        "gold_response":[],
                        "gold_database_result":[],
                        }
        self.all_data, self.testfiles = self._get_mwoz_data(mwoz_path)
        self.idx = 0
        self.dialog_history_limit = dialog_history_limit
        self.single_domain_only = single_domain_only
        self.with_slot_description = with_slot_description
        self.with_req_inf_differentiation = with_req_inf_differentiation
        self.with_all_slots = with_all_slots

        print("Processing mwoz...")
        for sample in tqdm(self.all_data):
            if sample in self.testfiles:
                dialogue_log = self.all_data[sample]["log"]
                self._process_dialogue_log(sample=sample,
                                           dialogue_log=dialogue_log)

        self.dataset = pd.DataFrame(self.dataset)
        if single_domain_only:
            for index, row in tqdm(self.dataset.iterrows()):
                if len(row["domains"]) != 1:
                    self.dataset.drop(index, inplace=True)

    def _get_mwoz_data(self, mwoz_path):
        data_path = os.path.join(mwoz_path, "data.json")
        testListFile_path = os.path.join(mwoz_path, "testListFile.txt")

        with open(data_path, "r") as f:
            all_data = json.load(f)
            
        with open(testListFile_path, "r") as f:
            testfiles = f.read()
        testfiles = testfiles.split("\n")
        return all_data, testfiles
    
    def _process_dialogue_log(self, sample, dialogue_log):

        dialog_history_memory = []
        dialog_history = ""
        domains = self._get_domains_from_log(dialogue_log)
        slots = self._get_slots_from_domains(domains, 
                                             self.with_slot_description,
                                             self.with_req_inf_differentiation,
                                             self.with_all_slots) # or all

        for turn_nb, turn in enumerate(dialogue_log):

            if turn_nb % 2 == 0:
                speaker = "USER"
            else:
                speaker = "SYSTEM"

            utterance = f"""{speaker}: {turn["text"]}\n"""
            dialogue_context = dialog_history + utterance
            dialog_act = turn["dialog_act"]
            prompt = self._build_prompt(mode="dst",
                                        slots=slots,
                                        dialogue_context=dialogue_context) 


            if self.dialog_history_limit != 0:
                if self.dialog_history_limit == -1:
                    self.dialog_history_limit = len(dialogue_log)

                if len(dialog_history_memory) >= self.dialog_history_limit:
                    dialog_history_memory.pop(0)
                dialog_history_memory.append(utterance)
                dialog_history = "".join(dialog_history_memory)

            self.idx += 1
            if turn_nb % 2 == 0:
                self.dataset["gold_bs"].append(dialog_act)
                self.dataset["dialogue_context"].append(dialogue_context)
                self.dataset["gold_database_result"].append(None) 
                self.dataset["turn"].append(turn_nb//2)
                self.dataset["domains"].append(domains)
                self.dataset["id"].append(self.idx//2)
                self.dataset["dialogue_id"].append(sample)
                self.dataset["prompt"].append(prompt)
            else:
                self.dataset["gold_response"].append(utterance)
                self.dataset["gold_act"].append(dialog_act)


    def _get_domains_from_log(self, dialogue_log):
        domains = []
        all_domains = ["restaurant", "taxi", "hotel", "train", "attraction"]
        for log in dialogue_log:
            for domain_act in log["dialog_act"]:
                domain = domain_act.split("-")[0].lower()
                if domain in all_domains and domain not in domains:
                    domains.append(domain)
        return domains

2023-05-28 01:56:37.515872: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-28 01:56:38.450455: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-05-28 01:56:38.450564: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
mwoz_path = "/home/willy/instructod/MultiWOZ_2.1/"
dialog_history_limit = 0
single_domain_only = False
with_slot_description = False
with_req_inf_differentiation = False
with_all_slots = True
mwoz = MWOZ_Dataset(CONFIG, 
                    mwoz_path,
                    dialog_history_limit,
                    with_slot_description,
                    with_req_inf_differentiation,
                    single_domain_only,
                    with_all_slots)
dataset = mwoz.dataset

Processing mwoz...


100%|██████████| 10438/10438 [00:00<00:00, 13068.15it/s]


In [3]:
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to utilize.
    """
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The path of the HuggingFace model."}
    )
    use_int8: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use int8 model or not."}
    )
    use_deepspeed: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use deepspeed model or not."}
    )

In [4]:
model_args = ModelArguments(model_name_or_path="google/flan-t5-xxl",
                            use_int8=True)

In [5]:
def load_pipeline(model_args):
    print(f'Loading {model_args.model_name_or_path}...')
    
    if model_args.use_int8:
        if "t5" in model_args.model_name_or_path or "flan-ul" in model_args.model_name_or_path:
            tokenizer = T5Tokenizer.from_pretrained(model_args.model_name_or_path)
            model = T5ForConditionalGeneration.from_pretrained(
                model_args.model_name_or_path, device_map="auto", load_in_8bit=True)
            generator = pipeline(
                task="text2text-generation", model=model, tokenizer=tokenizer, device_map="auto")
            # results = generator(prompt_dict["prompts"])
            # prompt_dict["results"] = [result["generated_text"] for result in results]
            return generator
        elif "flan-alpaca" in model_args.model_name_or_path:
            tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_args.model_name_or_path, device_map="auto", load_in_8bit=True)
            generator = pipeline(
                task="text2text-generation", model=model, tokenizer=tokenizer, device_map="auto")
            # results = generator(prompt_dict["prompts"], max_length=300)
            # prompt_dict["results"] = [result["generated_text"] for result in results]
            return generator

In [6]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [7]:
generator = load_pipeline(model_args)

Loading google/flan-t5-xxl...


Downloading (…)"spiece.model";:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/674 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/50.8k [00:00<?, ?B/s]

Downloading (…)00001-of-00005.bin";:   0%|          | 0.00/9.45G [00:00<?, ?B/s]

Downloading (…)00002-of-00005.bin";:   0%|          | 0.00/9.60G [00:00<?, ?B/s]

In [20]:
prompt = list(dataset["prompt"][0:5])

In [21]:
results = generator(prompt)

In [22]:
results

[{'generated_text': "slot1:'destination', 'Pizza Hut Fen Dit"},
 {'generated_text': "slot1:'leaveat', '17:15'"},
 {'generated_text': 'SYSTEM: Thank you!'},
 {'generated_text': 'SYSTEM: You are all set.'},
 {'generated_text': 'SYSTEM: Nusha is a restaurant.'}]