In [12]:
import numpy as np
import torch
import gensim
from gensim import models
from tqdm import tqdm

In [4]:
import os

In [2]:
dataset_path = '../data/MP'

In [5]:
user2id = np.load(os.path.join(dataset_path, 'user_mapper.npy'), allow_pickle=True).item()
location2id = np.load(os.path.join(dataset_path, 'location_mapper.npy'), allow_pickle=True).item()

In [8]:
user_loc_matrix = np.zeros((len(user2id), len(location2id)))

In [9]:
user_loc_matrix.shape

(10000, 20607)

In [13]:
with open(os.path.join(dataset_path, f'train.csv'), 'r', encoding='utf8') as file:
    lines = file.readlines()
    for line in tqdm(lines, desc=f'Preprocess data'):
        stay_points = line.strip().split(',')[1:]
        user = line.strip().split(',')[0]
        for i in range(len(stay_points)):
            location, _ = stay_points[i].split('@')
            user_loc_matrix[user2id[user], location2id[location]] += 1

Preprocess data: 100%|██████████| 10000/10000 [00:01<00:00, 7568.56it/s]


In [14]:
num_users, num_locations = user_loc_matrix.shape
dictionary = gensim.corpora.Dictionary([[str(i)] for i in range(num_locations)])
corpus = []
for user in user_loc_matrix:
    user_doc = [str(loc) for loc, count in enumerate(user) for _ in range(int(count))]
    corpus.append(dictionary.doc2bow(user_doc))

In [17]:
dictionary.num_docs

20607

In [15]:
corpus[0]

[(247, 1),
 (1200, 1),
 (1341, 3),
 (1635, 1),
 (1981, 2),
 (2360, 1),
 (3780, 1),
 (4111, 1),
 (4813, 13),
 (5025, 2),
 (5820, 2),
 (6984, 1),
 (7401, 1),
 (7691, 1),
 (8452, 1),
 (8948, 1),
 (9376, 3),
 (10856, 4),
 (12901, 1),
 (13247, 46),
 (13601, 1),
 (14346, 18),
 (15995, 2),
 (16658, 1),
 (16741, 2),
 (17738, 3),
 (18974, 1),
 (19587, 2),
 (19921, 2),
 (20197, 1)]

In [18]:
topic_num = 450
print(f'Generating a probability distribution... topic: {topic_num}')
lda = models.LdaModel(corpus, num_topics=topic_num, random_state=42)

Generating a probability distribution... topic: 450


In [25]:
lda.print_topic(topicno=0)

'0.102*"15879" + 0.050*"4363" + 0.042*"10018" + 0.038*"11203" + 0.032*"8530" + 0.031*"12402" + 0.029*"4421" + 0.023*"16884" + 0.021*"8916" + 0.020*"19408"'

In [37]:
lda.print_topics(num_topics=-1)

[(0,
  '0.102*"15879" + 0.050*"4363" + 0.042*"10018" + 0.038*"11203" + 0.032*"8530" + 0.031*"12402" + 0.029*"4421" + 0.023*"16884" + 0.021*"8916" + 0.020*"19408"'),
 (1,
  '0.110*"4611" + 0.108*"17365" + 0.105*"10471" + 0.083*"19388" + 0.070*"12765" + 0.054*"1954" + 0.052*"11739" + 0.040*"15184" + 0.026*"2069" + 0.025*"4857"'),
 (2,
  '0.137*"3010" + 0.103*"17530" + 0.099*"16725" + 0.094*"1758" + 0.093*"19916" + 0.080*"6179" + 0.079*"11888" + 0.067*"4559" + 0.030*"13710" + 0.016*"14622"'),
 (3,
  '0.128*"4653" + 0.107*"7069" + 0.087*"10583" + 0.073*"7121" + 0.069*"8905" + 0.066*"14750" + 0.048*"5313" + 0.042*"16865" + 0.033*"5627" + 0.025*"11065"'),
 (4,
  '0.497*"3081" + 0.083*"20150" + 0.070*"5220" + 0.059*"475" + 0.041*"69" + 0.030*"9630" + 0.029*"11334" + 0.015*"11572" + 0.014*"19984" + 0.014*"3763"'),
 (5,
  '0.121*"17166" + 0.108*"5786" + 0.103*"9086" + 0.065*"13596" + 0.065*"9608" + 0.050*"18103" + 0.050*"3681" + 0.048*"362" + 0.037*"14598" + 0.030*"15043"'),
 (6,
  '0.121*"3074

In [38]:
user_topics = np.zeros((num_users, topic_num))
for i, user in enumerate(user_loc_matrix):
    user_doc = [str(loc) for loc, count in enumerate(user) for _ in range(int(count))]
    for item in lda[dictionary.doc2bow(user_doc)]:
        j = item[0]
        prob = item[1]
        # user_topic: user i从topic j采样的概率
        user_topics[i, j] = prob
np.save(os.path.join(dataset_path, f'user_topic_loc_{topic_num}.npy'),np.array(user_topics))

In [40]:
user_topics = np.load(os.path.join(dataset_path, 'user_topic_loc_450.npy'), allow_pickle=True)

In [41]:
type(user_topics)

numpy.ndarray

In [44]:
sum = 0
nonzero_indices = np.nonzero(user_topics[0])[0]
for idx in nonzero_indices:
    print(f"Index {idx}: {user_topics[0][idx]}")
    sum += user_topics[0][idx]
print(sum)

Index 59: 0.019908057525753975
Index 122: 0.015168470330536366
Index 131: 0.01228119432926178
Index 143: 0.02340230718255043
Index 187: 0.05154679715633392
Index 218: 0.5545382499694824
Index 258: 0.011096815578639507
Index 299: 0.13765831291675568
Index 326: 0.014963103458285332
Index 348: 0.05595025420188904
Index 430: 0.07887287437915802
0.9753864370286465
