In [1]:
import pandas as pd
import numpy as np
from seqrec import TIFUKNN
import seqrec.data

In [2]:
all_train_baskets = pd.read_csv("datasets/instacart30k/train_baskets.csv.gz")
all_validation_baskets = pd.read_csv("datasets/instacart30k/valid_baskets.csv")
all_test_baskets = pd.read_csv("datasets/instacart30k/test_baskets.csv") 

In [3]:
seed = 1311
num_users = 10000

np.random.seed(seed)

unique_user_ids = list(all_train_baskets.user_id.unique())
sampled_users = np.random.choice(unique_user_ids, num_users)
train_baskets = all_train_baskets[all_train_baskets.user_id.isin(sampled_users)]
validation_baskets = all_validation_baskets[all_validation_baskets.user_id.isin(sampled_users)]
test_baskets = all_test_baskets[all_test_baskets.user_id.isin(sampled_users)] 

seqrec.data.index_consecutive('user_id', [train_baskets, validation_baskets, test_baskets])
seqrec.data.index_consecutive('item_id', [train_baskets, validation_baskets, test_baskets])

In [4]:
len(train_baskets)

1367487

In [5]:
%%time
tifu = TIFUKNN(train_baskets, k=10, kplus=40)

Setup 2895.09
Rust reps 33475.296


116247
8430
--Creating transpose of R...


Neighbors 89291.03499999999
CPU times: user 41.4 s, sys: 44.9 s, total: 1min 26s
Wall time: 1min 47s


--Computing row norms...
--Configuring for top-k -- num_threads: 8; pinning? false;
--Scheduling parallel top-k computation...


In [6]:
tifu.retrieve_for(5)

[(7279, 0.44695767760276794),
 (5184, 0.4224347472190857),
 (8173, 0.4104841351509094),
 (5178, 0.38635870814323425),
 (2446, 0.3803006708621979),
 (2351, 0.3799600899219513),
 (8268, 0.37488076090812683),
 (8393, 0.3720211982727051),
 (1334, 0.3645762503147125),
 (7378, 0.3610772490501404),
 (3832, 0.3516826331615448),
 (2281, 0.34509581327438354),
 (4324, 0.33925291895866394),
 (6633, 0.33916357159614563),
 (2570, 0.33826830983161926),
 (3907, 0.3344752788543701),
 (837, 0.33420756459236145),
 (4050, 0.3338434100151062),
 (7189, 0.3302789628505707),
 (5018, 0.3290339410305023),
 (6840, 0.32792091369628906),
 (4940, 0.3270239531993866),
 (8162, 0.31851470470428467),
 (787, 0.3180810213088989),
 (7571, 0.3167879581451416),
 (6245, 0.3160063922405243),
 (2586, 0.3147020936012268),
 (3725, 0.3144998848438263),
 (3382, 0.3144665062427521),
 (5404, 0.3137041926383972),
 (2583, 0.3125326931476593),
 (8158, 0.312523752450943),
 (3881, 0.3124087452888489),
 (2333, 0.31074056029319763),
 (7175

In [7]:
tifu.predict(5, [neighbor for neighbor, _ in tifu.retrieve_for(5)[:10]], 5)

[9256, 5810, 15419, 14863, 31707]