In [1]:
import sys
sys.path.insert(0, "/workspace/recommendation")

In [2]:
from argparse import ArgumentParser
import pandas as pd
from load import implicit_load
import torch
import tqdm

In [3]:
MIN_RATINGS = 20
USER_COLUMN = 'user_id'
ITEM_COLUMN = 'item_id'

In [4]:
class _TestNegSampler:
    def __init__(self, train_ratings, nb_neg):
        self.nb_neg = nb_neg
        self.nb_users = int(train_ratings[:, 0].max()) + 1
        self.nb_items = int(train_ratings[:, 1].max()) + 1

        # compute unique ids for quickly created hash set and fast lookup
        ids = (train_ratings[:, 0] * self.nb_items) + train_ratings[:, 1]
        self.set = set(ids)

    def generate(self, batch_size=128*1024):
        users = torch.arange(0, self.nb_users).reshape([1, -1]).repeat([self.nb_neg, 1]).transpose(0, 1).reshape(-1)

        items = [-1] * len(users)

        random_items = torch.LongTensor(batch_size).random_(0, self.nb_items).tolist()
        print('Generating validation negatives...')
        for idx, u in enumerate(tqdm.tqdm(users.tolist())):
            if not random_items:
                random_items = torch.LongTensor(batch_size).random_(0, self.nb_items).tolist()
            j = random_items.pop()
            while u * self.nb_items + j in self.set:
                if not random_items:
                    random_items = torch.LongTensor(batch_size).random_(0, self.nb_items).tolist()
                j = random_items.pop()

            items[idx] = j
        items = torch.LongTensor(items)
        return items

In [16]:
df = implicit_load('/data/ml-20m/ratings.csv', sort=False)

print("Filtering out users with less than {} ratings".format(MIN_RATINGS))
grouped = df.groupby(USER_COLUMN)
df = grouped.filter(lambda x: len(x) >= MIN_RATINGS)

print("Mapping original user and item IDs to new sequential IDs")
df[USER_COLUMN], unique_users = pd.factorize(df[USER_COLUMN])
df[ITEM_COLUMN], unique_items = pd.factorize(df[ITEM_COLUMN])


20000263 ratings on 26744 items from 138493 users from 1995-01-09 11:46:44 to 2015-03-31 06:40:02
Filtering out users with less than 20 ratings
Mapping original user and item IDs to new sequential IDs


In [24]:
df[ITEM_COLUMN]

0               0
1               1
2               2
3               3
4               4
5               5
6               6
7               7
8               8
9               9
10             10
11             11
12             12
13             13
14             14
15             15
16             16
17             17
18             18
19             19
20             20
21             21
22             22
23             23
24             24
25             25
26             26
27             27
28             28
29             29
            ...  
20000233      960
20000234     3832
20000235      962
20000236     6121
20000237     3996
20000238     1944
20000239     2603
20000240      971
20000241     1945
20000242      972
20000243      974
20000244     1953
20000245     5180
20000246      988
20000247      989
20000248     9547
20000249     1001
20000250     3938
20000251     1805
20000252     1007
20000253     6156
20000254     1810
20000255     1812
20000256    11011
20000257  

In [17]:
unique_users

Int64Index([     1,      2,      3,      4,      5,      6,      7,      8,
                 9,     10,
            ...
            138484, 138485, 138486, 138487, 138488, 138489, 138490, 138491,
            138492, 138493],
           dtype='int64', length=138493)

In [18]:
unique_items

Int64Index([     2,     29,     32,     47,     50,    112,    151,    223,
               253,    260,
            ...
            104307, 106170, 106401, 113539, 118856, 121017, 121019, 121021,
            110167, 110510],
           dtype='int64', length=26744)

In [22]:
import pickle

with open('./mappings.pickle', 'wb') as handle:
    pickle.dump({"users": unique_users, "items": unique_items}, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [6]:
# Need to sort before popping to get last item
df.sort_values(by='timestamp', inplace=True)

# clean up data
del df['rating'], df['timestamp']
df = df.drop_duplicates() # assuming it keeps order

# now we have filtered and sorted by time data, we can split test data out
grouped_sorted = df.groupby(USER_COLUMN, group_keys=False)
test_data = grouped_sorted.tail(1).sort_values(by='user_id')
# need to pop for each group
train_data = grouped_sorted.apply(lambda x: x.iloc[:-1])



In [7]:
# Note: no way to keep reference training data ordering because use of python set and multi-process
# It should not matter since it will be later randomized again
# save train and val data that is fixed.
train_ratings = torch.from_numpy(train_data.values)
torch.save(train_ratings, './train_ratings.pt')
test_ratings = torch.from_numpy(test_data.values)
torch.save(test_ratings, './test_ratings.pt')



In [13]:
train_ratings.shape

torch.Size([19861770, 2])

In [14]:
test_ratings

tensor([[     0,     62],
        [     1,     15],
        [     2,    336],
        ...,
        [138490,    173],
        [138491,    204],
        [138492,    695]])

In [10]:
sampler = _TestNegSampler(train_ratings.cpu().numpy(), 100)  # using 100 negative samples
test_negs = sampler.generate().cuda()
test_negs = test_negs.reshape(-1, 100)
torch.save(test_negs, './test_negatives.pt')

  1%|          | 80696/13849300 [00:00<00:17, 806955.00it/s]

Generating validation negatives...


100%|██████████| 13849300/13849300 [00:14<00:00, 966810.44it/s]


In [11]:
test_negs.shape

torch.Size([138493, 100])

In [12]:
test_negs[1,]

tensor([14470, 19596,  5496,  7876, 13781, 18589, 13426, 10928,  4491,  9756,
         8371, 16324, 19714, 17614, 17084, 10339, 19900, 19954,  3198,  7704,
        21833, 24603, 22393, 16497, 12137, 10897, 16824,  6476, 19759,  9787,
         2589, 19789,  6598, 18668, 26078, 23213,  8732, 20727, 11042, 13098,
        19331,  2694,   774, 16017,  2733,  7195,  3234,  5478,  9518, 25528,
        12890, 20064,  4193, 24937, 21779, 20982,  4279, 13174,  2057,  8464,
         4302, 18896,   546, 17086,  3973, 15116, 24690, 23495, 15982,  9509,
        11061, 16351,  5154, 13412, 18309, 12249,  3764, 22858, 25954,  1904,
         7456, 24602, 26063, 15207, 18617,  9906, 21567,  7472, 17297,  6302,
         9482, 15818,  6989, 14464,  2344, 10292, 23338,  6301, 22121,  7240],
       device='cuda:0')