In [10]:
"""
Build per-user test samples from behaviors.tsv
and save them to toy/processed/sampled_users.json.
format:
    [
        {
            "history_news": list of news_id,
            "tests": list of {
                "news_ids": list of news_id,
                "labels": list of label, aligned with news_ids (1 for the positive news listed first, 0 for each sampled negative)
            }
        }
    ]
"""
import torch
import json
import os
import pickle
import numpy as np
import pandas as pd
import time
import random
from tqdm import tqdm

NUMBER_OF_NEGATIVE_NEWS = 4  # negatives sampled per positive impression in parse_user
NUMBER_PER_GROUP = 3  # users sampled per history-length group; for demo, reset it for your own purpose
SEED = 42  # seeds the `random` module so negative and user sampling are reproducible

# Both negative sampling and per-group user sampling go through the global
# `random` module, so seed it once here, before any sampling cell runs.
random.seed(SEED)

target_dir = 'MIND/small/valid'
target_file = os.path.join(target_dir, 'behaviors.tsv')

In [11]:
# behaviors.tsv is tab-separated and has no header row, so pandas assigns
# integer column labels (0-4, as shown in the preview below).
df = pd.read_csv(target_file, header=None, sep='\t')

In [12]:
df.head(n=5)  # preview the first five rows

Unnamed: 0,0,1,2,3,4
0,1,U80234,11/15/2019 12:37:50 PM,N55189 N46039 N51741 N53234 N11276 N264 N40716...,N28682-0 N48740-0 N31958-1 N34130-0 N6916-0 N5...
1,2,U60458,11/15/2019 7:11:50 AM,N58715 N32109 N51180 N33438 N54827 N28488 N611...,N20036-0 N23513-1 N32536-0 N46976-0 N35216-0 N...
2,3,U44190,11/15/2019 9:55:12 AM,N56253 N1150 N55189 N16233 N61704 N51706 N5303...,N36779-0 N62365-0 N58098-0 N5472-0 N13408-0 N5...
3,4,U87380,11/15/2019 3:12:46 PM,N63554 N49153 N28678 N23232 N43369 N58518 N444...,N6950-0 N60215-0 N6074-0 N11930-0 N6916-0 N248...
4,5,U9444,11/15/2019 8:25:46 AM,N51692 N18285 N26015 N22679 N55556,N5940-1 N23513-0 N49285-0 N23355-0 N19990-0 N3...


In [13]:
def parse_user(history_str, data_str, num_negatives=None, rng=None):
    """Turn one behaviors.tsv row into a user sample.

    Args:
        history_str: space-separated news ids (the history column). Must be a
            string; a missing history that pandas reads as NaN raises
            AttributeError, which the calling loop treats as "skip this user".
        data_str: space-separated "<news_id>-<label>" impressions. Label '1'
            marks a positive and '0' a negative; any other label is ignored.
        num_negatives: negatives sampled per positive. Defaults to
            NUMBER_OF_NEGATIVE_NEWS.
        rng: object with a ``sample(population, k)`` method, such as
            ``random.Random``. Defaults to the global ``random`` module.

    Returns:
        dict with 'history_news' (list of news ids) and 'tests'. 'tests' has
        one entry per positive: 'news_ids' is the positive followed by
        ``num_negatives`` distinct sampled negatives, and 'labels' is
        [1, 0, ..., 0].

    Raises:
        ValueError: if an impression is not of the form "<id>-<label>", or if
            the row has at least one positive but fewer than
            ``num_negatives`` negatives.
    """
    if num_negatives is None:
        num_negatives = NUMBER_OF_NEGATIVE_NEWS
    if rng is None:
        rng = random

    history_news = history_str.split()

    positive_news, negative_news = [], []
    for item in data_str.split():
        # rsplit on the last '-' so only the trailing label is split off;
        # an item without '-' fails to unpack and raises ValueError.
        news_id, label = item.rsplit('-', 1)
        if label == '1':
            positive_news.append(news_id)
        elif label == '0':
            negative_news.append(news_id)

    # Sampling only happens when there is a positive, so only check then.
    if positive_news and len(negative_news) < num_negatives:
        raise ValueError(
            f'need {num_negatives} negatives per positive, '
            f'got only {len(negative_news)}'
        )

    tests = []
    for pos in positive_news:
        negs = rng.sample(negative_news, num_negatives)
        tests.append({
            'news_ids': [pos] + negs,
            'labels': [1] + [0] * num_negatives
        })
    return {
        'history_news': history_news,
        'tests': tests
    }

In [14]:
# Parse every behaviors row into a user sample. Rows parse_user cannot handle
# are skipped but counted rather than silently discarded:
#   - AttributeError: missing history or impressions (read as NaN)
#   - ValueError: too few negatives to sample, or a malformed impression
#   - IndexError: malformed impression, raised by the original split-based parser
# Any other exception is a real bug and is allowed to surface.
all_users = []
skipped_users = 0
for history_str, data_str in tqdm(zip(df[3], df[4]), total=len(df)):
    try:
        all_users.append(parse_user(history_str, data_str))
    except (AttributeError, IndexError, ValueError):
        skipped_users += 1
print(f'parsed {len(all_users)} users, skipped {skipped_users}')

100%|██████████| 73152/73152 [00:05<00:00, 14238.08it/s]


In [15]:
def get_groupid_by_length(length, bucket_size=10, num_groups=11):
    """Map a history length to a length-bucket id.

    Lengths fall into ``num_groups`` consecutive buckets of ``bucket_size``.
    With the defaults: 0-9 -> 0, 10-19 -> 1, ..., 100-109 -> 10.
    Lengths of ``bucket_size * num_groups`` or more (110 by default)
    return -1, meaning "too long, exclude this user".
    """
    if length >= bucket_size * num_groups:
        return -1
    return length // bucket_size

# Bucket users into 11 history-length groups (ids 0..10). Users whose group id
# is negative (history too long) are left out. Then show each group's size.
grouped_users = {gid: [] for gid in range(11)}
for user in all_users:
    group_id = get_groupid_by_length(len(user['history_news']))
    if group_id < 0:
        continue
    grouped_users[group_id].append(user)
print([len(members) for members in grouped_users.values()])

[17555, 13637, 8938, 6032, 4458, 3153, 2314, 1786, 1434, 1118, 806]


In [16]:
# Randomly sample up to NUMBER_PER_GROUP users from each history-length group.
# random.sample raises ValueError when asked for more items than a group holds,
# so the sample size is capped at the group size and the shortfall is reported.
sampled_users = []
for gid in sorted(grouped_users):
    members = grouped_users[gid]
    k = min(NUMBER_PER_GROUP, len(members))
    if k < NUMBER_PER_GROUP:
        print(f'group {gid}: only {len(members)} users, sampling all of them')
    sampled_users.extend(random.sample(members, k))

In [17]:
import json

# Persist the sampled users as one JSON array.
output_dir = 'toy/processed'
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, 'sampled_users.json')
with open(output_path, 'w') as out_file:
    json.dump(sampled_users, out_file)