In [1]:
from collections import defaultdict
from datetime import datetime
import torch
import numpy as np
import pandas as pd

In [4]:
ml_1m_path = '../../../../datasets/ml-1m'

In [6]:
file_path = ml_1m_path + '/ratings.dat'
# data = pd.read_csv(file_path, encoding='latin-1', sep=',', engine='python', index_col='userId')
data = pd.read_csv(file_path, sep='::', engine='python',
                   names=['userId', 'movieId', 'rating', 'timestamp'], index_col='userId')

In [7]:
data

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,1193,5,978300760
1,661,3,978302109
1,914,3,978301968
1,3408,4,978300275
1,2355,5,978824291
...,...,...,...
6040,1091,1,956716541
6040,1094,5,956704887
6040,562,5,956704746
6040,1096,4,956715648


In [8]:
actions = data
actions = actions.groupby('movieId').filter(lambda  x: len(x) >= 5)
actions = actions.groupby('userId').filter(lambda  x: len(x) >= 5)

In [9]:
actions = actions.groupby('userId', group_keys=False).apply(lambda  x: x.sort_values('timestamp'))

In [10]:
actions

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,3186,4,978300019
1,1721,4,978300055
1,1022,5,978300055
1,1270,5,978300055
1,2340,3,978300103
...,...,...,...
6040,2917,4,997454429
6040,1784,3,997454464
6040,1921,4,997454464
6040,161,3,997454486


In [11]:
# data = data.groupby('userId')
min = 100000
max = 0
total = 0
for i in range(1, actions.index.max()):
    total += len(actions.loc[i])
    if len(data.loc[i]) < min:
        min = len(actions.loc[i])
    if len(data.loc[i]) > max:
        max = len(actions.loc[i])
min, max, total/actions.index.max()

(20, 2277, 165.44205298013244)

In [12]:
usermap = dict()
usernum = 0
itemmap = dict()
itemnum = 0
# reorder the userid and itemid (keep the same step with original SASRec code)
for _id, row in actions.iterrows():
    if _id in usermap:
        userid = usermap[_id]
    else:
        usernum += 1
        userid = usernum
        usermap[_id] = userid

    if row.movieId in itemmap:
        itemid = itemmap[row.movieId]
    else:
        itemnum += 1
        itemid = itemnum
        itemmap[row.movieId] = itemid

In [13]:
usermap.__len__(), itemmap.__len__()

(6040, 3416)

In [14]:
actions['movieId'] = actions['movieId'].map(itemmap)
actions.index = actions.index.map(usermap)

In [15]:
actions

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,1,4,978300019
1,2,4,978300055
1,3,5,978300055
1,4,5,978300055
1,5,3,978300103
...,...,...,...
6040,1249,4,997454429
6040,88,3,997454464
6040,371,4,997454464
6040,464,3,997454486


In [16]:
userMaxTime = actions.groupby('userId').timestamp.max()

In [17]:
userMaxTime

userId
1       978824351
2       978300174
3       978298504
4       978294282
5       978246585
          ...    
6036    956755196
6037    956801840
6038    956717204
6039    956758029
6040    998315055
Name: timestamp, Length: 6040, dtype: int64

In [18]:
day = 86400
num_day = 14
userSplitTime = userMaxTime - num_day * day

In [19]:
actions

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,1,4,978300019
1,2,4,978300055
1,3,5,978300055
1,4,5,978300055
1,5,3,978300103
...,...,...,...
6040,1249,4,997454429
6040,88,3,997454464
6040,371,4,997454464
6040,464,3,997454486


In [20]:
userSplitTime

userId
1       977614751
2       977090574
3       977088904
4       977084682
5       977036985
          ...    
6036    955545596
6037    955592240
6038    955507604
6039    955548429
6040    997105455
Name: timestamp, Length: 6040, dtype: int64

In [21]:
def filter_input(group):
    user_id = group.index[0]
    limit = userSplitTime.loc[user_id]
    return group[group['timestamp'] < limit]

def filter_target(group):
    user_id = group.index[0]
    limit = userSplitTime.loc[user_id]
    return group[group['timestamp'] >= limit]

In [64]:
input_data = actions.groupby('userId').apply(filter_input)
target_window = actions.groupby('userId').apply(filter_target)

In [65]:
input_data = input_data.reset_index(level=1, drop=True)

In [24]:
input_data

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
10,552,3,978224375
10,553,3,978224375
10,430,4,978224375
10,554,4,978224400
10,56,4,978224400
...,...,...,...
6040,515,4,964828782
6040,2799,2,964828782
6040,984,5,964828799
6040,183,4,964828900


In [66]:
target_window = target_window.reset_index(level=1, drop=True)

In [26]:
target_window

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,1,4,978300019
1,2,4,978300055
1,3,5,978300055
1,4,5,978300055
1,5,3,978300103
...,...,...,...
6040,1249,4,997454429
6040,88,3,997454464
6040,371,4,997454464
6040,464,3,997454486


In [67]:
input_idx = set(input_data.index)
target_idx = set(target_window.index)
valid_userid = input_idx.intersection(target_idx)

In [28]:
len(valid_userid)

1628

In [49]:
valid_user = sorted(list(valid_userid))
valid_user

[10,
 19,
 20,
 22,
 23,
 24,
 36,
 44,
 59,
 63,
 65,
 68,
 73,
 74,
 76,
 80,
 85,
 86,
 89,
 90,
 92,
 96,
 97,
 99,
 102,
 104,
 114,
 116,
 122,
 123,
 124,
 127,
 131,
 133,
 134,
 137,
 140,
 142,
 146,
 148,
 149,
 150,
 151,
 153,
 157,
 160,
 162,
 164,
 166,
 169,
 173,
 175,
 180,
 183,
 184,
 188,
 192,
 193,
 195,
 198,
 201,
 204,
 208,
 215,
 218,
 224,
 228,
 229,
 231,
 234,
 235,
 237,
 239,
 242,
 248,
 259,
 264,
 267,
 270,
 271,
 279,
 285,
 293,
 299,
 302,
 303,
 306,
 308,
 310,
 311,
 314,
 319,
 321,
 322,
 326,
 329,
 330,
 331,
 332,
 333,
 338,
 343,
 349,
 351,
 355,
 356,
 362,
 366,
 368,
 372,
 375,
 376,
 382,
 387,
 391,
 392,
 398,
 402,
 403,
 404,
 405,
 408,
 411,
 412,
 415,
 419,
 420,
 422,
 424,
 438,
 439,
 442,
 453,
 454,
 462,
 480,
 482,
 487,
 496,
 498,
 500,
 507,
 514,
 516,
 518,
 519,
 520,
 528,
 529,
 531,
 549,
 566,
 582,
 583,
 584,
 587,
 588,
 593,
 595,
 600,
 608,
 609,
 611,
 623,
 624,
 626,
 633,
 635,
 639,
 641,
 648

In [30]:
len(valid_user)

1628

In [31]:
len(input_data), len(target_window)

(417212, 582399)

In [32]:
valid_user

[10,
 19,
 20,
 22,
 23,
 24,
 36,
 44,
 59,
 63,
 65,
 68,
 73,
 74,
 76,
 80,
 85,
 86,
 89,
 90,
 92,
 96,
 97,
 99,
 102,
 104,
 114,
 116,
 122,
 123,
 124,
 127,
 131,
 133,
 134,
 137,
 140,
 142,
 146,
 148,
 149,
 150,
 151,
 153,
 157,
 160,
 162,
 164,
 166,
 169,
 173,
 175,
 180,
 183,
 184,
 188,
 192,
 193,
 195,
 198,
 201,
 204,
 208,
 215,
 218,
 224,
 228,
 229,
 231,
 234,
 235,
 237,
 239,
 242,
 248,
 259,
 264,
 267,
 270,
 271,
 279,
 285,
 293,
 299,
 302,
 303,
 306,
 308,
 310,
 311,
 314,
 319,
 321,
 322,
 326,
 329,
 330,
 331,
 332,
 333,
 338,
 343,
 349,
 351,
 355,
 356,
 362,
 366,
 368,
 372,
 375,
 376,
 382,
 387,
 391,
 392,
 398,
 402,
 403,
 404,
 405,
 408,
 411,
 412,
 415,
 419,
 420,
 422,
 424,
 438,
 439,
 442,
 453,
 454,
 462,
 480,
 482,
 487,
 496,
 498,
 500,
 507,
 514,
 516,
 518,
 519,
 520,
 528,
 529,
 531,
 549,
 566,
 582,
 583,
 584,
 587,
 588,
 593,
 595,
 600,
 608,
 609,
 611,
 623,
 624,
 626,
 633,
 635,
 639,
 641,
 648

In [68]:
filted_actions = actions.loc[valid_user]
# input_data = input_data.loc[valid_user]
# target_window = target_window.loc[valid_user]

In [69]:
input_data

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
10,552,3,978224375
10,553,3,978224375
10,430,4,978224375
10,554,4,978224400
10,56,4,978224400
...,...,...,...
6040,515,4,964828782
6040,2799,2,964828782
6040,984,5,964828799
6040,183,4,964828900


In [54]:
target_window

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
10,765,3,979775053
10,766,3,979775054
10,60,4,979775131
10,268,3,979775159
10,767,3,979775159
...,...,...,...
6040,1249,4,997454429
6040,88,3,997454464
6040,371,4,997454464
6040,464,3,997454486


In [51]:
filted_actions

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
10,552,3,978224375
10,553,3,978224375
10,430,4,978224375
10,554,4,978224400
10,56,4,978224400
...,...,...,...
6040,1249,4,997454429
6040,88,3,997454464
6040,371,4,997454464
6040,464,3,997454486


In [70]:
# map the user and item
usermap = dict()
usernum = 0
itemmap = dict()
itemnum = 0
# reorder the userid and itemid (keep the same step with original SASRec code)
for _id, row in filted_actions.iterrows():
    if _id in usermap:
        userid = usermap[_id]
    else:
        usernum += 1
        userid = usernum
        usermap[_id] = userid

    if row.movieId in itemmap:
        itemid = itemmap[row.movieId]
    else:
        itemnum += 1
        itemid = itemnum
        itemmap[row.movieId] = itemid

In [71]:
input_data['movieId'] = input_data['movieId'].map(itemmap)
input_data.index = input_data.index.map(usermap)

In [72]:
target_window['movieId'] = target_window['movieId'].map(itemmap)
target_window.index = target_window.index.map(usermap)

In [37]:
list(usermap.values())

[1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,
 185

In [73]:
input_data

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,1,3,978224375
1,2,3,978224375
1,3,4,978224375
1,4,4,978224400
1,5,4,978224400
...,...,...,...
1628,7,4,964828782
1628,2391,2,964828782
1628,755,5,964828799
1628,976,4,964828900


In [74]:
target_window

Unnamed: 0_level_0,movieId,rating,timestamp
userId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
,1297.0,4,978300019
,311.0,4,978300055
,386.0,5,978300055
,37.0,5,978300055
,1530.0,3,978300103
...,...,...,...
1628.0,734.0,4,997454429
1628.0,357.0,3,997454464
1628.0,145.0,4,997454464
1628.0,579.0,3,997454486


Int64Index([    1,     1,     1,     1,     1,     1,     1,     1,     1,
                1,
            ...
            39019, 39019, 39019, 39019, 39019, 39019, 39019, 39019, 39019,
            39019],
           dtype='int64', name='userId', length=10504095)

In [61]:
with open('ml-1m_train.txt', 'w') as f:
    for _id in list(usermap.values()):
        movie_id_data = input_data.loc[_id].movieId
        if isinstance(movie_id_data, pd.Series):
            for movie_id in movie_id_data:
                f.write('%d %d\n' % (_id, movie_id))
        else:  # it's a single value, not a list
            f.write('%d %d\n' % (_id, movie_id_data))

In [75]:
with open('ml-1m_target.txt', 'w') as f:
    for _id in list(usermap.values()):
        movie_id_data = target_window.loc[_id].movieId
        if isinstance(movie_id_data, pd.Series):
            for movie_id in movie_id_data:
                f.write('%d %d\n' % (_id, movie_id))
        else:  # it's a single value, not a list
            f.write('%d %d\n' % (_id, movie_id_data))

# DATA reading

In [125]:
def data_partition_window_InputTarget_byP(f_train, f_target):
    usernum = 0
    itemnum = 0
    user_input = defaultdict(list)
    user_target = defaultdict(list)
    f = open('%s.txt' % f_train, 'r')
    # read from each line
    for line in f:
        u, i = line.rstrip().split(' ')
        u = int(u)
        i = int(i)
        usernum = max(u, usernum)
        itemnum = max(i, itemnum)
        user_input[u].append(i)
        # count user and items
    f = open('%s.txt' % f_target, 'r')
    # read from each line
    for line in f:
        u, i = line.rstrip().split(' ')
        u = int(u)
        i = int(i)
        itemnum = max(i, itemnum)
        user_target[u].append(i)
        # count user and items

    return [user_input, user_target, usernum, itemnum]

[_,_,a,b] = data_partition_window_InputTarget_byP('ml-20m_train', 'ml-20m_target')

AxisError: axis 1 is out of bounds for array of dimension 0