In [None]:
import torch
import csv
import json

def load_movie_data(csv_path, train_path, valid_path, test_path,
                    meta_path='/projects/prjs1158/KG/redail/efficient_unified_crs_place/data/REDIAL/movie_db'):
    """Build movie tables for every movie that is mentioned in the dialogue splits.

    Args:
        csv_path: CSV file with at least ``movieId``, ``movieName`` and
            ``nbMentions`` columns.
        train_path, valid_path, test_path: dialogue JSON files scanned by
            ``get_mentioned_movie_ids`` for ``@<id>`` mentions.
        meta_path: file read with ``torch.load``; its values are strings of the
            form ``"<title> [SEP] <meta> ..."``. Defaults to the path that used to
            be hard-coded here.

    Returns:
        items_db: ``{movieId: {"movieName", "meta", "nbMentions", "new_item_id"}}``
            for mentioned movies, in CSV row order.
        new_item_db: ``{new_item_id: {"movieName", "meta"}}`` with contiguous ids
            starting at 0.
        name2meta: ``{link: meta}`` keyed by the link that
            ``get_mentioned_movie_ids`` recorded for each movie id. Movies without
            a link are omitted.
    """
    import re

    # Collect every movie id mentioned (as @ID) in any split, plus the id -> link
    # mapping. The same dict is threaded through all three calls.
    id2link = {}
    mentioned_ids = set()
    for split_path in (train_path, valid_path, test_path):
        split_ids, id2link = get_mentioned_movie_ids(split_path, id2link)
        mentioned_ids.update(split_ids)

    # Map title -> meta text. The title is the part before the first [SEP]; the
    # remaining [SEP]-separated segments are joined with a single space.
    # NOTE(review): titles are matched against lower-cased CSV names below, so this
    # assumes the stored titles are already lower-case -- TODO confirm.
    meta_info = torch.load(meta_path)
    meta_dict = {}
    for meta in meta_info.values():
        title, *rest = meta.split('[SEP]')
        meta_dict[title.strip()] = " ".join(rest)

    # Parenthesised parts of a CSV movie name, e.g. the release year.
    paren_re = re.compile(r'\s*\(.*?\)\s*')

    items_db = {}
    new_item_db = {}
    name2meta = {}
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f)
        new_id = 0
        for row in reader:
            movie_id = int(row['movieId'])
            # Keep only movies that are actually mentioned in some dialogue.
            if movie_id not in mentioned_ids:
                continue

            # Replace each "(...)" with a space, then collapse whitespace, so a
            # mid-title parenthetical does not glue neighbouring words together.
            # Lower-case the result for the meta lookup.
            movie_name = ' '.join(paren_re.sub(' ', row['movieName']).split()).lower()
            meta = meta_dict.get(movie_name, '')

            # Unlinked movies all share the link '' -- storing them would keep
            # overwriting a single bogus '' entry, so they are skipped here.
            link = id2link.get(movie_id, '')
            if link:
                name2meta[link] = meta

            items_db[movie_id] = {
                "movieName": movie_name,
                'meta': meta,
                "nbMentions": int(row['nbMentions']),
                "new_item_id": new_id,
            }
            new_item_db[new_id] = {"movieName": movie_name, 'meta': meta}
            new_id += 1
    return items_db, new_item_db, name2meta

def get_mentioned_movie_ids(data, id2link=None):
    """Extract every movie id mentioned as ``@<digits>`` in a dialogue file.

    Args:
        data: path to a JSON file. Its top-level value is iterated as conversations;
            each conversation has a ``"dialog"`` list of turns. Each turn has a
            ``"text"`` iterable of strings and a ``"movies"`` list.
        id2link: dict updated in place with ``{movie_id: link}``. A new dict is
            created when omitted.

    Returns:
        ``(mentioned_ids, id2link)``: the set of int movie ids found, and the same
        ``id2link`` dict.

    The i-th ``@id`` in a turn is paired with the i-th entry of ``turn["movies"]``.
    This is positional alignment -- NOTE(review): it assumes both are in the same
    order. A mention with no matching entry records ``''``, but only when no link
    is known yet, so a later unmatched mention cannot erase an earlier link.
    """
    import re

    if id2link is None:
        id2link = {}
    with open(data, 'r', encoding='utf-8') as f:
        conversations = json.load(f)

    mention_re = re.compile(r'@(\d+)')
    mentioned_ids = set()
    for conv in conversations:
        for turn in conv["dialog"]:
            mentions = mention_re.findall(" ".join(turn["text"]))
            if not mentions:
                continue
            links = turn["movies"]
            for pos, raw_id in enumerate(mentions):
                movie_id = int(raw_id)
                mentioned_ids.add(movie_id)
                if pos < len(links):
                    id2link[movie_id] = links[pos]
                else:
                    id2link.setdefault(movie_id, '')
    return mentioned_ids, id2link

# Inputs for this cell: the movie-mention table and the three ReDial dialogue splits.
csv_path = '/projects/prjs1158/KG/redail/MESE_review/DATA/nltk/movies_with_mentions.csv'
train_path, valid_path, test_path = (
    '/projects/prjs1158/KG/redail/UniMIND_meta/data/redail/nltk/train_data.json',
    '/projects/prjs1158/KG/redail/UniMIND_meta/data/redail/nltk/valid_data.json',
    '/projects/prjs1158/KG/redail/UniMIND_meta/data/redail/nltk/test_data.json',
)

items_db, new_item_db, name2meta = load_movie_data(
    csv_path, train_path, valid_path, test_path,
)

In [2]:
import json
import re

import html
from tqdm.auto import tqdm

# Matches ReDial movie placeholders such as "@12345".
movie_pattern = re.compile(r'@\d+')


def process_utt(utt, movieid2name, replace_movieId):
    """Normalize one raw utterance string.

    When ``replace_movieId`` is true, each ``@<id>`` placeholder whose id
    (without the ``@``) is a key of ``movieid2name`` is replaced by that name,
    with the name's internal whitespace collapsed. Unknown ids are left as-is.
    The text is then whitespace-collapsed and HTML-unescaped.
    """
    def _substitute(m):
        placeholder = m.group(0)
        key = placeholder[1:]
        if key not in movieid2name:
            return placeholder
        return ' '.join(movieid2name[key].split())

    text = movie_pattern.sub(_substitute, utt) if replace_movieId else utt
    return html.unescape(' '.join(text.split()))


def process(data_file, out_file, movie_set, entity2id=None):
    """Convert a ReDial ``*_dbpedia.jsonl`` file into per-turn training samples.

    Consecutive messages from the same sender are merged into one turn. For every
    turn, one JSON line is written to ``out_file`` with these keys:

    - ``context``: texts of all previous turns (``['']`` for the first turn);
    - ``resp``: this turn's text, with ``@id`` placeholders replaced by movie names;
    - ``rec``: de-duplicated ids of the movies and entities in this turn;
    - ``entity``: de-duplicated movie/entity ids from all previous turns.

    Dialogues with no messages are skipped. Movie ids seen in any turn are added
    to ``movie_set`` in place.

    Args:
        data_file: input JSON-lines path, one dialogue per line.
        out_file: output JSON-lines path (overwritten).
        movie_set: set mutated in place with movie entity ids.
        entity2id: mapping from entity/movie name to id. Defaults to the
            module-level ``entity2id``, which the notebook defines before calling
            this function.
    """
    import itertools

    if entity2id is None:
        # Fall back to the notebook-global mapping (original behaviour).
        entity2id = globals()['entity2id']

    def to_ids(names):
        # Keep only names known to the mapping, preserving order and duplicates.
        return [entity2id[name] for name in names if name in entity2id]

    with open(data_file, 'r', encoding='utf-8') as fin, open(out_file, 'w', encoding='utf-8') as fout:
        for line in tqdm(fin):
            dialog = json.loads(line)
            messages = dialog['messages']
            if len(messages) == 0:
                continue

            movieid2name = dialog['movieMentions']
            context = []
            entity_list = []
            # groupby merges *consecutive* messages with the same sender into one turn.
            for _, group in itertools.groupby(messages, key=lambda m: m['senderWorkerId']):
                utt_turn, entity_turn, movie_turn = [], [], []
                for message in group:
                    utt_turn.append(process_utt(message['text'], movieid2name, replace_movieId=True))
                    entity_turn.extend(to_ids(message['entity']))
                    movie_turn.extend(to_ids(message['movie']))
                resp = ' '.join(utt_turn)

                flat_context_entities = [e for turn_ents in entity_list for e in turn_ents]
                context_entities = list(set(flat_context_entities))

                if not context:
                    context.append('')
                sample = {
                    'context': context,
                    'resp': resp,
                    'rec': list(set(movie_turn + entity_turn)),
                    'entity': context_entities,
                }
                # Serialized immediately, so later appends to `context` don't affect it.
                fout.write(json.dumps(sample, ensure_ascii=False) + '\n')

                context.append(resp)
                entity_list.append(movie_turn + entity_turn)
                movie_set |= set(movie_turn)



with open('entity2id.json', 'r', encoding='utf-8') as f:
    entity2id = json.load(f)

# Movie entity ids seen in any split; `process` adds to this set in place.
item_set = set()
split_files = (
    ('valid_data_dbpedia.jsonl', 'valid_data_processed.jsonl'),
    ('test_data_dbpedia.jsonl', 'test_data_processed.jsonl'),
    ('train_data_dbpedia.jsonl', 'train_data_processed.jsonl'),
)
for src_file, dst_file in split_files:
    process(src_file, dst_file, item_set)

with open('item_ids.json', 'w', encoding='utf-8') as f:
    json.dump(list(item_set), f, ensure_ascii=False)
print(f'#item: {len(item_set)}')


  from .autonotebook import tqdm as notebook_tqdm
1000it [00:00, 2549.04it/s]
1342it [00:00, 2828.35it/s]
9006it [00:02, 3094.39it/s]


#item: 6281


In [None]:
# Map each KG entity id to the meta/review text of its movie. Only links that
# exist in entity2id are kept; when two links share an id, the later one wins.
# A comprehension avoids leaving a module-level `id` that shadows the builtin.
id2review = {
    entity2id[link]: meta
    for link, meta in name2meta.items()
    if link in entity2id
}

In [7]:
# Persist the entity-id -> meta text mapping built in the previous cell.
# NOTE: unpickling runs arbitrary code; load this file back only from a trusted source.
import pickle

with open('id2review.pkl', 'wb') as out_fh:
    pickle.dump(id2review, out_fh)