In [21]:
import torch
import json
from nltk.corpus import wordnet, stopwords
from tqdm import tqdm
import nltk
from model.comet import Comet
from transformers import AdamW
import pickle
import numpy as np
import warnings
# Silence every Python warning for the rest of the session (including
# deprecation warnings raised by the imported libraries).
warnings.filterwarnings("ignore")
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [22]:
def DATA_FILES(data_dir):
    """Return the ``.npy`` file paths of every dataset split under ``data_dir``.

    Args:
        data_dir: Directory prefix that is joined to each file name with ``/``.

    Returns:
        A dict keyed by ``"train"``, ``"dev"`` and ``"test"`` (in that order).
        Each value is a list of four paths, ordered as dialog, target,
        emotion, situation. ``encode`` relies on this positional order.
    """
    fields = ("dialog", "target", "emotion", "situation")
    return {
        split: [f"{data_dir}/sys_{field}_texts.{split}.npy" for field in fields]
        for split in ("train", "dev", "test")
    }

# Contraction -> expansion table. `process_sent` applies the entries in
# insertion order as plain substring replacements on lower-cased text, so
# every key must be lower case and must appear only once.
word_pairs = {
    "it's": "it is",
    "don't": "do not",
    "doesn't": "does not",
    "didn't": "did not",
    "you'd": "you would",
    "you're": "you are",
    "you'll": "you will",
    "i'm": "i am",
    "they're": "they are",
    "that's": "that is",
    "what's": "what is",
    "couldn't": "could not",
    "i've": "i have",
    "we've": "we have",
    "can't": "cannot",
    "i'd": "i would",
    "aren't": "are not",
    "isn't": "is not",
    "wasn't": "was not",
    "weren't": "were not",
    "won't": "will not",
    "there's": "there is",
    "there're": "there are",
}

# Emotion label -> integer class id. The id is the label's position in the
# tuple below, so the order of the labels must not change.
emo_map = {
    label: class_id
    for class_id, label in enumerate(
        (
            "surprised",
            "excited",
            "annoyed",
            "proud",
            "angry",
            "sad",
            "grateful",
            "lonely",
            "impressed",
            "afraid",
            "disgusted",
            "confident",
            "terrified",
            "hopeful",
            "anxious",
            "disappointed",
            "joyful",
            "prepared",
            "guilty",
            "furious",
            "nostalgic",
            "jealous",
            "anticipating",
            "embarrassed",
            "content",
            "devastated",
            "sentimental",
            "caring",
            "trusting",
            "ashamed",
            "apprehensive",
            "faithful",
        )
    )
}



In [23]:
# Relation names passed to `comet.generate`, one query per relation for each
# utterance; the order here is the order of the per-relation results.
relations = ["xIntent", "xAttr", "xWant", "xEffect", "xNeed","xReact"]
# Emotion lexicon: first element of the top-level JSON array in NRCDict.json.
# Only membership (`word in emotion_lexicon`) is used by this notebook.
# The `with` block closes the file handle instead of leaking it.
with open("data/NRCDict.json", encoding="utf-8") as lexicon_file:
    emotion_lexicon = json.load(lexicon_file)[0]
# English stop words from NLTK; these are never collected as emotion words.
stop_words = stopwords.words("english")


class Lang:
    """Growable vocabulary mapping words to integer ids, with word counts.

    Attributes:
        word2index: word -> id.
        word2count: word -> number of times the word was indexed (seed
            tokens start at 1).
        index2word: id -> word.
        n_words: number of entries; also the id given to the next new word.
    """

    def __init__(self, init_index2word):
        """Seed the vocabulary from an ``{id: token}`` mapping.

        Args:
            init_index2word: Mapping of ids to reserved tokens. Ids are
                converted with ``int`` and tokens with ``str`` when building
                ``word2index``. The mapping itself is not modified.
        """
        # word -> index
        self.word2index = {str(v): int(k) for k, v in init_index2word.items()}
        # word -> count
        self.word2count = {str(v): 1 for k, v in init_index2word.items()}
        # Copy the seed mapping so that growing the vocabulary in
        # `index_word` never mutates the dict owned by the caller.
        self.index2word = dict(init_index2word)
        self.n_words = len(init_index2word)

    def index_words(self, sentence):
        """Index every word of ``sentence`` (an iterable of strings).

        Each word is stripped of surrounding whitespace before indexing.
        """
        for word in sentence:
            self.index_word(word.strip())

    def index_word(self, word):
        """Add ``word`` to the vocabulary if unseen, otherwise bump its count."""
        if word not in self.word2index:
            # Unseen word: assign the next free id.
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1


def get_wordnet_pos(tag):
    """Map a POS tag string to the matching WordNet POS constant.

    Tags starting with ``J``, ``V``, ``N`` or ``R`` map to the adjective,
    verb, noun and adverb constants respectively. Any other tag gives ``None``.
    """
    prefix_to_pos = (
        ("J", wordnet.ADJ),
        ("V", wordnet.VERB),
        ("N", wordnet.NOUN),
        ("R", wordnet.ADV),
    )
    for prefix, pos in prefix_to_pos:
        if tag.startswith(prefix):
            return pos
    return None

def process_sent(sentence):
    """Lower-case ``sentence``, expand contractions and tokenize it.

    Every key of ``word_pairs`` is replaced by its expansion (plain substring
    replacement, in the table's order). The text is then split with
    ``nltk.word_tokenize``.

    Returns:
        The list of tokens.
    """
    expanded = sentence.lower()
    for contraction in word_pairs:
        expanded = expanded.replace(contraction, word_pairs[contraction])
    return nltk.word_tokenize(expanded)


def get_commonsense(comet, item, data_dict):
    """Generate commonsense inferences for one tokenized utterance.

    The tokens in ``item`` are joined with spaces and passed to
    ``comet.generate`` once per entry of ``relations``. Each generated string
    is tokenized with ``process_sent``.

    Args:
        comet: Object exposing ``generate(event_text, relation)`` that returns
            an iterable of strings.
        item: List of tokens forming the utterance.
        data_dict: Dict whose ``"utt_cs"`` list receives one new element: a
            list with one entry per relation, each a list of token lists.
    """
    event_text = " ".join(item)
    per_relation = [
        [process_sent(generated) for generated in comet.generate(event_text, relation)]
        for relation in relations
    ]
    data_dict["utt_cs"].append(per_relation)

def encode_ctx(vocab, items, data_dict, comet):
    """Tokenize dialog contexts and collect emotion words and commonsense.

    For every context (a sequence of utterance strings) exactly ONE element is
    appended to each of ``data_dict["context"]``,
    ``data_dict["emotion_context"]`` and ``data_dict["utt_cs"]``, so the three
    lists stay index-aligned with one another.

    Args:
        vocab: Vocabulary exposing ``index_words(tokens)``; updated in place.
        items: Iterable of contexts, each an iterable of utterance strings.
        data_dict: Output dict with list values under the keys ``"context"``,
            ``"emotion_context"`` and ``"utt_cs"``.
        comet: Commonsense generator forwarded to ``get_commonsense``.

    Appended per context:
        * ``context``: list of token lists, one per utterance.
        * ``emotion_context``: flat list of non-stop-word tokens that are
          tagged as adjectives or appear in ``emotion_lexicon``.
        * ``utt_cs``: list with one ``get_commonsense`` result per utterance.
    """
    # Set membership instead of scanning the stop-word list for every token.
    stop_set = frozenset(stop_words)

    for ctx in tqdm(items):
        ctx_list = []
        e_list = []
        # `get_commonsense` appends to the "utt_cs" list of the dict it is
        # given. Collect the utterance-level results in a per-context scratch
        # dict so that `data_dict["utt_cs"]` grows once per context rather
        # than once per utterance. Otherwise its length no longer matches
        # `data_dict["context"]` for multi-utterance contexts.
        ctx_cs = {"utt_cs": []}

        for utterance in ctx:
            tokens = process_sent(utterance)
            ctx_list.append(tokens)
            vocab.index_words(tokens)

            for word, tag in nltk.pos_tag(tokens):
                if word in stop_set:
                    continue
                if get_wordnet_pos(tag) == wordnet.ADJ or word in emotion_lexicon:
                    e_list.append(word)

            get_commonsense(comet, tokens, ctx_cs)

        data_dict["context"].append(ctx_list)
        data_dict["emotion_context"].append(e_list)
        data_dict["utt_cs"].append(ctx_cs["utt_cs"])
        
def encode(vocab, files):
    """Build the preprocessed data dict for one dataset split.

    Args:
        vocab: Vocabulary exposing ``index_words``; updated in place with
            every token seen in contexts, targets and situations.
        files: Sequence whose first four entries are, in order, the dialog
            contexts, the target texts, the emotion labels and the situation
            texts of the split.

    Returns:
        A dict with the keys ``context``, ``target``, ``emotion``,
        ``situation``, ``emotion_context`` and ``utt_cs``. ``emotion`` holds
        ``files[2]`` as given; the other values are lists.

    Raises:
        AssertionError: if the six values do not all have the same length.
    """
    # Keys filled directly from `files`, in positional order.
    file_backed_keys = ("context", "target", "emotion", "situation")
    data_dict = {
        key: [] for key in (*file_backed_keys, "emotion_context", "utt_cs")
    }
    comet = Comet("data/Comet", device)

    for position, key in enumerate(file_backed_keys):
        items = files[position]
        if key == "context":
            # Also fills "emotion_context" and "utt_cs".
            encode_ctx(vocab, items, data_dict, comet)
            continue
        if key == "emotion":
            # Labels are stored untouched: no tokenization, no vocab update.
            data_dict[key] = items
            continue
        for raw_text in tqdm(items):
            tokens = process_sent(raw_text)
            data_dict[key].append(tokens)
            vocab.index_words(tokens)

    # Every field must hold exactly one entry per example.
    expected = len(data_dict["context"])
    assert all(
        len(data_dict[key]) == expected
        for key in ("target", "emotion", "situation", "emotion_context", "utt_cs")
    )

    return data_dict

In [24]:
# Seed the vocabulary with the reserved special tokens (ids 0-6).
vocab = Lang(
    {
        0: "UNK",
        1: "PAD",
        2: "EOS",
        3: "SOS",
        4: "USR",
        5: "SYS",
        6: "CLS",
    }
)

# Load the four arrays of every split, in the order DATA_FILES lists them.
# NOTE(review): allow_pickle=True makes np.load unpickle arbitrary objects;
# only point this at trusted .npy files.
files = DATA_FILES("data/ED")
train_files, dev_files, test_files = (
    [np.load(path, allow_pickle=True) for path in files[split]]
    for split in ("train", "dev", "test")
)

# The same `vocab` instance is passed to all three calls and is extended in
# place by each of them. This cell is long-running: the recorded progress bar
# shows about 1.28 s per item over 40250 train items.
data_train = encode(vocab, train_files)
data_dev = encode(vocab, dev_files)
data_test = encode(vocab, test_files)

  2%|▏         | 779/40250 [16:52<14:04:47,  1.28s/it]

In [None]:
# Persist the three encoded splits together with the shared vocabulary.
cache_file = "dataset_preproc1.2.p"
with open(cache_file, "wb") as f:
    pickle.dump([data_train, data_dev, data_test, vocab], f)
# Report success only once the `with` block has flushed and closed the file,
# so the message is never printed for a write that failed on close.
print("Saved PICKLE")

Saved PICKLE
