In [6]:
import csv
import itertools
import json
import os
import re

import torch

def load_movie_data(csv_path, train_path, valid_path, test_path,
                    meta_path='/projects/prjs1158/KG/redail/efficient_unified_crs_place/data/REDIAL/movie_db'):
    """Build movie lookup tables for the movies mentioned in the ReDial dialogs.

    Only CSV rows whose ``movieId`` appears as an ``@<id>`` mention in the
    train, valid or test dialogs are kept. Each kept title (parenthesised parts
    removed, lower-cased) is matched against the metadata stored in ``meta_path``.

    Args:
        csv_path: CSV file with at least ``movieId``, ``movieName`` and
            ``nbMentions`` columns.
        train_path, valid_path, test_path: JSON dialog files parsed by
            ``get_mentioned_movie_ids``.
        meta_path: file read with ``torch.load``. Its values are split on
            ``'[SEP]'``: the first field is the title and the remaining fields
            are the metadata. Defaults to the path previously hard-coded here.

    Returns:
        items_db: ``{movieId: {"movieName", "meta", "nbMentions", "new_item_id"}}``.
        new_item_db: ``{new_item_id: {"movieName", "meta"}}``. Ids are assigned
            consecutively in CSV order.
        name2meta: ``{link: meta}``, where the link comes from the dialogs'
            ``movies`` lists. ``meta`` is ``''`` when no metadata matched.
            Movies without a link are omitted.
    """
    # Gather every movie id mentioned in any split. get_mentioned_movie_ids
    # fills `id2link` in place and hands the same dict back.
    id2link = {}
    mentioned_ids = set()
    for split_path in (train_path, valid_path, test_path):
        split_ids, id2link = get_mentioned_movie_ids(split_path, id2link)
        mentioned_ids.update(split_ids)

    # NOTE(review): torch.load unpickles the file; only point meta_path at trusted data.
    meta_info = torch.load(meta_path)
    # Map title -> remaining '[SEP]' fields joined with spaces.
    meta_dict = {}
    for meta in meta_info.values():
        title, *fields = meta.split('[SEP]')
        # Lower-case the key: CSV titles are lower-cased before lookup, so a
        # mixed-case key could never be matched.
        meta_dict[title.strip().lower()] = " ".join(fields)

    items_db = {}
    new_item_db = {}
    name2meta = {}
    new_item_id = 0
    with open(csv_path, 'r', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            movie_id = int(row['movieId'])
            # Only keep movies that are actually mentioned in a dialog.
            if movie_id not in mentioned_ids:
                continue
            # Remove content inside brackets (e.g. the year), then normalise case.
            movie_name = re.sub(r'\s*\(.*?\)\s*', '', row['movieName']).strip().lower()
            meta = meta_dict.get(movie_name, '')

            # Movies with no known link would all collide on the '' key (last
            # writer wins), so they are not added to name2meta.
            link = id2link.get(movie_id, '')
            if link:
                name2meta[link] = meta
            items_db[movie_id] = {
                "movieName": movie_name,
                'meta': meta,
                "nbMentions": int(row['nbMentions']),
                "new_item_id": new_item_id,
            }
            new_item_db[new_item_id] = {
                "movieName": movie_name,
                'meta': meta,
            }
            new_item_id += 1
    return items_db, new_item_db, name2meta

def get_mentioned_movie_ids(data, id2link):
    """Collect the movie ids mentioned as ``@<id>`` in a ReDial dialog file.

    Args:
        data: path to a JSON file holding a list of conversations. Each
            conversation has a ``"dialog"`` list of turns with ``"text"`` and
            ``"movies"`` entries.
        id2link: dict updated in place with ``{movie_id: link}``. The i-th
            ``@<id>`` in a turn is paired with the i-th entry of that turn's
            ``"movies"`` list. When the turn has no such entry, ``''`` is stored,
            but only if no link is already known for that id.

    Returns:
        ``(mentioned_ids, id2link)``: the set of int ids found in the file, and
        the same ``id2link`` dict that was passed in.
    """
    with open(data, 'r', encoding='utf-8') as f:
        conversations = json.load(f)

    mention_pattern = re.compile(r'@(\d+)')
    mentioned_ids = set()
    for conv in conversations:
        for turn in conv["dialog"]:
            # NOTE(review): assumes turn["text"] is a list of tokens. A plain
            # string would be joined character by character and break the regex.
            text = " ".join(turn["text"])
            links = turn.get("movies") or []
            # NOTE(review): pairing mentions with links by position assumes both
            # lists are in the same order and equally de-duplicated — confirm.
            for position, raw_id in enumerate(mention_pattern.findall(text)):
                movie_id = int(raw_id)
                mentioned_ids.add(movie_id)
                if position < len(links):
                    id2link[movie_id] = links[position]
                else:
                    # Don't clobber a link learned from an earlier turn or file.
                    id2link.setdefault(movie_id, '')
    return mentioned_ids, id2link

# Data locations. The defaults are the original cluster paths; set the
# environment variables to run the notebook on another machine.
MESE_NLTK_DIR = os.environ.get('MESE_NLTK_DIR', '/projects/prjs1158/KG/redail/MESE_review/DATA/nltk')
UNIMIND_NLTK_DIR = os.environ.get('UNIMIND_NLTK_DIR', '/projects/prjs1158/KG/redail/UniMIND_meta/data/redail/nltk')

csv_path = os.path.join(MESE_NLTK_DIR, 'movies_with_mentions.csv')
train_path = os.path.join(UNIMIND_NLTK_DIR, 'train_data.json')
valid_path = os.path.join(UNIMIND_NLTK_DIR, 'valid_data.json')
test_path = os.path.join(UNIMIND_NLTK_DIR, 'test_data.json')

items_db, new_item_db, name2meta = load_movie_data(csv_path, train_path, valid_path, test_path)

  meta_info = torch.load(meta_path)


In [7]:
# Preview the first few link -> metadata pairs. The previous loop never
# incremented its counter, so its `break` never fired and every entry was printed.
N_PREVIEW = 10
for link, meta in itertools.islice(name2meta.items(), N_PREVIEW):
    print(link)
    print(meta)

<http://dbpedia.org/resource/Angels_in_the_Outfield_(1994_film)>
 Danny Glover, Brenda Fricker, Tony Danza   William Dear   Comedy, Family, Fantasy   When a boy prays for a chance to have a family if the California Angels win the pennant, angels are assigned to make that possible.
<http://dbpedia.org/resource/Eddie_and_the_Cruisers>
 Tom Berenger, Michael Paré, Joe Pantoliano   Martin Davidson   Drama, Music, Mystery   A television newswoman picks up the story of a 1960s rock band whose long-lost leader - Eddie Wilson - may still be alive, while searching for the missing tapes of the band's never-released album.
<http://dbpedia.org/resource/Ninja_Assassin>
 Rain, Joon Lee, Jonathan Chan-Pensley   James McTeigue   Action, Thriller   A young ninja turns his back on the orphanage that raised him, leading to a confrontation with a fellow ninja from the clan.
<http://dbpedia.org/resource/Orgazmo>
 Trey Parker, Dian Bachar, Robyn Lynne Raab   Trey Parker   Comedy   Naive young Mormon Joe You

In [8]:
# Append each linked item's metadata to the message text (test split).
from tqdm import tqdm

ori_test_data_path = './test_data_dbpedia.jsonl'
new_test_data_path = './test_data_dbpedia_review.jsonl'

# Read with an explicit encoding as well: the platform default may not be
# UTF-8, while the output file is written as UTF-8.
with open(ori_test_data_path, 'r', encoding='utf-8') as f, \
        open(new_test_data_path, 'w', encoding='utf-8') as fout:
    for line in tqdm(f, desc='test'):
        dialog = json.loads(line)
        for message in dialog['messages']:
            # Entities first, then movies (same order as before). dict.fromkeys
            # keeps that order but drops repeats, so a link is annotated once.
            # NOTE(review): assumes the entries are hashable link strings.
            links = dict.fromkeys([*(message.get('entity') or ()),
                                   *(message.get('movie') or ())])
            for link in links:
                review = name2meta.get(link, '')
                # Links without metadata would only add an empty " Meta: " tag.
                if review:
                    message['text'] = message['text'] + " Meta: " + review
        # write to new file
        fout.write(json.dumps(dialog, ensure_ascii=False) + '\n')



0it [00:00, ?it/s]

1342it [00:00, 3095.39it/s]


In [10]:
# Append each linked item's metadata to the message text (valid split).
# The variable names are reused from the test cell; here they hold the valid paths.
from tqdm import tqdm

ori_test_data_path = './valid_data_dbpedia.jsonl'
new_test_data_path = './valid_data_dbpedia_review.jsonl'

# Read with an explicit encoding as well: the platform default may not be
# UTF-8, while the output file is written as UTF-8.
with open(ori_test_data_path, 'r', encoding='utf-8') as f, \
        open(new_test_data_path, 'w', encoding='utf-8') as fout:
    for line in tqdm(f, desc='valid'):
        dialog = json.loads(line)
        for message in dialog['messages']:
            # Entities first, then movies (same order as before). dict.fromkeys
            # keeps that order but drops repeats, so a link is annotated once.
            # NOTE(review): assumes the entries are hashable link strings.
            links = dict.fromkeys([*(message.get('entity') or ()),
                                   *(message.get('movie') or ())])
            for link in links:
                review = name2meta.get(link, '')
                # Links without metadata would only add an empty " Meta: " tag.
                if review:
                    message['text'] = message['text'] + " Meta: " + review
        # write to new file
        fout.write(json.dumps(dialog, ensure_ascii=False) + '\n')

1000it [00:00, 3464.61it/s]


In [11]:
# Append each linked item's metadata to the message text (train split).
# The variable names are reused from the test cell; here they hold the train paths.
from tqdm import tqdm

ori_test_data_path = './train_data_dbpedia.jsonl'
new_test_data_path = './train_data_dbpedia_review.jsonl'

# Read with an explicit encoding as well: the platform default may not be
# UTF-8, while the output file is written as UTF-8.
with open(ori_test_data_path, 'r', encoding='utf-8') as f, \
        open(new_test_data_path, 'w', encoding='utf-8') as fout:
    for line in tqdm(f, desc='train'):
        dialog = json.loads(line)
        for message in dialog['messages']:
            # Entities first, then movies (same order as before). dict.fromkeys
            # keeps that order but drops repeats, so a link is annotated once.
            # NOTE(review): assumes the entries are hashable link strings.
            links = dict.fromkeys([*(message.get('entity') or ()),
                                   *(message.get('movie') or ())])
            for link in links:
                review = name2meta.get(link, '')
                # Links without metadata would only add an empty " Meta: " tag.
                if review:
                    message['text'] = message['text'] + " Meta: " + review
        # write to new file
        fout.write(json.dumps(dialog, ensure_ascii=False) + '\n')


9006it [00:02, 3524.17it/s]
