In [7]:
'''
parse news data from news.tsv
save to processed/news.pkl
format:
    key: news_id
    value: {
        "title": str,
        "embedding": torch.Tensor
    }
'''

import torch
import json
import os
import pickle
import numpy as np
import time
from tqdm import tqdm
from openai import OpenAI, RateLimitError

target_dir = 'MIND/small/valid'
target_file = os.path.join(target_dir, 'news.tsv')

In [8]:
class OpenAIServer:
    def __init__(self):
        self.openai_client = OpenAI(
            api_key="<YOUR_OPENAI_API_KEY>"
        )

    def embedding(self, sentences):
        """
        @param sentences: list[str]
        @return: np.array of shape (1536,)
        """
        response = self.openai_client.embeddings.create(
            input=sentences,
            model="text-embedding-3-small"
        )
        embeddings = [data.embedding for data in response.data]
        return np.array(embeddings)

server = OpenAIServer()

In [9]:
sampled_users = json.load(open('toy/processed/sampled_users.json'))
all_used_news = set()
for user in sampled_users:
    all_used_news.update(user['history_news'])
    for test in user['tests']:
        all_used_news.update(test['news_ids'])

In [10]:
len(sampled_users), len(all_used_news)

(33, 1483)

In [11]:
def parse_line(line):
    first_split = line.find('\t')
    second_split = line.find('\t', first_split+1)
    third_split = line.find('\t', second_split+1)
    fourth_split = line.find('http')
    news_id = line[:first_split]
    category = line[first_split+1:second_split]
    sub_category = line[second_split+1:third_split]
    title = line[third_split+1:fourth_split].replace('\t', ' ')
    content = "[{}-{}] {}".format(category, sub_category, title)
    return news_id, content

news_pool = {}
batch_size = 10000
with open(target_file, 'r') as f:
    lines = f.readlines()
    for i in tqdm(range(0, len(lines), batch_size)):
        batch_lines = lines[i:i+batch_size]
        batch_data = [parse_line(line) for line in batch_lines]
        batch_data = [data for data in batch_data if data[0] in all_used_news]  # Referenced, need to be encoded
        batch_ids = [data[0] for data in batch_data]
        batch_contents = [data[1] for data in batch_data]
        batch_embeddings = server.embedding(batch_contents)
        for j in range(len(batch_ids)):
            news_id, content, embedding = batch_ids[j], batch_contents[j], batch_embeddings[j]
            news_pool[news_id] = {
                "title": content,
                "embedding": embedding
            }

100%|██████████| 5/5 [00:10<00:00,  2.17s/it]


In [12]:
output_dir = 'toy/processed'
with open(os.path.join(output_dir, 'news.pkl'), 'wb') as f:
    pickle.dump(news_pool, f)