In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [2]:
# External dependencies
import os

import cudf  # cuDF is an implementation of Pandas-like Dataframe on GPU
# import rmm

import numpy as np
import nvtabular as nvt

from sklearn.model_selection import train_test_split

In [3]:
# Directory that holds the parquet inputs and receives the output of this
# notebook. Set the INPUT_DATA_DIR environment variable to override the
# "./data/" default.
INPUT_DATA_DIR = os.getenv("INPUT_DATA_DIR", os.path.expanduser("./data/"))

## Read Movie and Rating Features

In [4]:
# Load the per-movie feature table from INPUT_DATA_DIR and preview it.
movie_features_path = os.path.join(INPUT_DATA_DIR, "movie_features.parquet")
movie_features = cudf.read_parquet(movie_features_path)
movie_features.head()

Unnamed: 0,genres,movieId,tags_unique,tags_nunique
0,"[3, 4, 5, 6, 10]",1,"[477, 581, 640, 1857, 2175, 2817, 3538, 4395, ...",126
1,"[3, 5, 10]",2,"[1206, 4448, 5069, 5213, 7883, 8912, 9116, 925...",44
2,"[6, 16]",3,"[2196, 4356, 4448, 6484, 11460, 12274, 17647, ...",23
3,"[6, 9, 16]",4,"[4448, 31525, 34749, 34981, 39134, 48169, 6086...",8
4,[6],5,"[4448, 5693, 6977, 8912, 23756, 25354, 28604, ...",20


In [5]:
# Load the per-rating feature table from INPUT_DATA_DIR and preview it.
ratings_features_path = os.path.join(INPUT_DATA_DIR, "ratings_features.parquet")
ratings_features = cudf.read_parquet(ratings_features_path)
ratings_features.head()

Unnamed: 0,day,interaction,userId,movieId,rating,timestamp
0,4146,True,1,296,5.0,1147880044
1,4146,True,1,306,3.5,1147868817
2,4146,True,1,307,5.0,1147868828
3,4146,True,1,665,5.0,1147878820
4,4146,True,1,899,3.5,1147868510


## Join Ratings With Movie Features

In [6]:
# Attach each movie's features to every rating of that movie.
# The join key and join type are spelled out: with no `on=`, merge joins on
# every column name the two frames happen to share, so the result would change
# silently if either parquet file ever gained another common column.
joined_features = movie_features.merge(ratings_features, on="movieId", how="inner")
joined_features.head()

Unnamed: 0,genres,movieId,tags_unique,tags_nunique,day,interaction,userId,rating,timestamp
0,"[3, 4, 5, 6, 10]",1,"[477, 581, 640, 1857, 2175, 2817, 3538, 4395, ...",126,2728,True,188,5.0,1025333097
1,"[2, 3, 10]",44199,"[2385, 18428, 24178, 31375, 32524, 37622, 3849...",19,4715,True,155,4.5,1196999137
2,[1],45186,[46781],1,4715,True,155,3.5,1196999134
3,"[3, 4, 5, 6, 10]",1,"[477, 581, 640, 1857, 2175, 2817, 3538, 4395, ...",126,4449,True,160,4.5,1174081897
4,"[2, 3, 18]",10,"[187, 188, 189, 2817, 3416, 3862, 3911, 4448, ...",66,4449,True,160,4.0,1174082605


In [12]:
# NOTE(review): a stray bare `cd` used to live in this cell and has been removed.
# It was executed out of order (execution count 12) and left a TypeError in the
# notebook. When a bare `cd` does succeed as the IPython `%cd` automagic it
# switches the kernel to the home directory, which would make the relative
# default INPUT_DATA_DIR ("./data/") resolve to a different folder for the
# final to_parquet call.

In [7]:
# Draw, for every joined row, one random position in [0, tags_nunique); the
# next cell uses it to pick a single tag out of that row's tags_unique list.
# A seeded Generator makes the draw reproducible under Restart & Run All, and
# passing the whole count array as the upper bound replaces one Python-level
# randint call per row with a single vectorised call.
# As before, a row whose tags_nunique is 0 raises ValueError (empty range).
SAMPLING_SEED = 42
tag_counts = joined_features["tags_nunique"].to_pandas().to_numpy()
sampled_indices = np.random.default_rng(SAMPLING_SEED).integers(
    0, tag_counts, dtype=np.int32
)

In [8]:
# For each row, keep the single tag found at that row's sampled position.
sampled_tags = np.array(
    [
        tag_list[position]
        for position, tag_list in zip(
            sampled_indices, joined_features["tags_unique"].to_pandas()
        )
    ]
)

In [9]:
# Store the one sampled tag per row as a new "sampled_tag" column.
joined_features["sampled_tag"] = sampled_tags

In [10]:
# The tag list/count columns have been reduced to "sampled_tag", and
# "interaction" and "rating" are not used by any later cell, so remove them.
joined_features = joined_features.drop(
    columns=["tags_unique", "tags_nunique", "interaction", "rating"]
)

In [11]:
# Order rows by timestamp (ascending) before they are grouped per user and day.
joined_features = joined_features.sort_values(by="timestamp")

In [12]:
# Preview the first rows of the frame after the timestamp sort.
joined_features.head()

Unnamed: 0,genres,movieId,day,userId,timestamp,sampled_tag
158499,"[15, 18]",47,0,2262,789652009,53451
999500,"[9, 15]",57,385,13424,822873600,21387
999522,"[6, 9, 16]",11,385,13424,822873600,44344
3285717,"[15, 17, 18]",32,385,42937,822873600,29536
7845053,[6],18,385,102689,822873600,45249


In [13]:
# Build one row per (userId, day) pair. "collect" gathers a group's values into
# a list; movieId additionally gets a "count" so the list length is available
# as a plain number. The result has two-level column labels.
# NOTE(review): the lists are presumably in the timestamp order set by the
# earlier sort — confirm that groupby "collect" preserves row order.
group_keys = ["userId", "day"]
group_aggregations = {
    "sampled_tag": "collect",
    "genres": "collect",
    "timestamp": "collect",
    "movieId": ["collect", "count"],
}
grouped_examples = joined_features.groupby(group_keys).agg(group_aggregations)

In [14]:
# Preview the grouped frame: one row per (userId, day), list-valued columns.
grouped_examples.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,sampled_tag,genres,timestamp,movieId,movieId
Unnamed: 0_level_1,Unnamed: 1_level_1,collect,collect,collect,collect,count
userId,day,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
1,4146,"[58406, 45148, 1688, 8795, 61752, 3545, 3500, ...","[[6], [6], [3, 6, 17], [9, 16], [3, 9], [2, 3,...","[1147868053, 1147868097, 1147868414, 114786846...","[5952, 1653, 1250, 6539, 6377, 3448, 1088, 899...",53
2,4071,"[52780, 65105, 38902, 64349, 1452, 57559, 3972...","[[6], [6], [6], [6], [6, 9], [9], [7, 9, 15], ...","[1141415528, 1141415566, 1141415576, 114141558...","[5952, 497, 1374, 1653, 2640, 5445, 151, 236, ...",125
3,7521,"[8568, 52357, 4576, 24974, 28670, 24770, 6610,...","[[6, 9, 16], [9], [2, 6, 7, 9, 17], [3, 4, 5, ...","[1439472199, 1439472203, 1439472211, 143947221...","[356, 593, 1270, 1, 2571, 260, 318, 1196, 527,...",179
3,7688,"[67417, 71080, 30253, 55981, 25864, 22485, 467...","[[12], [12, 18], [7, 9, 18], [7, 9], [6, 9], [...","[1453904021, 1453904031, 1453904046, 145390404...","[1206, 1208, 44191, 32587, 40815, 36529, 45186...",10
3,8045,"[6042, 50075, 32067, 6526, 57987, 30637, 56885...","[[7, 9, 17, 18], [7, 9, 18], [7, 9], [9], [7, ...","[1484753654, 1484753766, 1484753808, 148475384...","[1089, 4011, 741, 778, 111, 214, 293, 1252, 33...",22


In [15]:
# NOTE(review): this cell repeats the preview shown by the cell directly above
# and produces the same output; it can be deleted.
grouped_examples.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,sampled_tag,genres,timestamp,movieId,movieId
Unnamed: 0_level_1,Unnamed: 1_level_1,collect,collect,collect,collect,count
userId,day,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
1,4146,"[58406, 45148, 1688, 8795, 61752, 3545, 3500, ...","[[6], [6], [3, 6, 17], [9, 16], [3, 9], [2, 3,...","[1147868053, 1147868097, 1147868414, 114786846...","[5952, 1653, 1250, 6539, 6377, 3448, 1088, 899...",53
2,4071,"[52780, 65105, 38902, 64349, 1452, 57559, 3972...","[[6], [6], [6], [6], [6, 9], [9], [7, 9, 15], ...","[1141415528, 1141415566, 1141415576, 114141558...","[5952, 497, 1374, 1653, 2640, 5445, 151, 236, ...",125
3,7521,"[8568, 52357, 4576, 24974, 28670, 24770, 6610,...","[[6, 9, 16], [9], [2, 6, 7, 9, 17], [3, 4, 5, ...","[1439472199, 1439472203, 1439472211, 143947221...","[356, 593, 1270, 1, 2571, 260, 318, 1196, 527,...",179
3,7688,"[67417, 71080, 30253, 55981, 25864, 22485, 467...","[[12], [12, 18], [7, 9, 18], [7, 9], [6, 9], [...","[1453904021, 1453904031, 1453904046, 145390404...","[1206, 1208, 44191, 32587, 40815, 36529, 45186...",10
3,8045,"[6042, 50075, 32067, 6526, 57987, 30637, 56885...","[[7, 9, 17, 18], [7, 9, 18], [7, 9], [9], [7, ...","[1484753654, 1484753766, 1484753808, 148475384...","[1089, 4011, 741, 778, 111, 214, 293, 1252, 33...",22


In [16]:
# Number of (userId, day) groups that contain more than one movie.
int((grouped_examples[("movieId", "count")] > 1).sum())

498253

In [17]:
# The last movieId collected for each group becomes that group's target item.
last_position = -1
grouped_examples["target_item"] = grouped_examples[("movieId", "collect")].list.get(last_position)

In [18]:
# Drop the final entry of every group's tag list: it belongs to the item that
# was set aside as target_item, so it must not remain among the inputs.
grouped_examples[("sampled_tag", "collect")] = np.array(
    [
        tag_sequence[:-1]
        for tag_sequence in grouped_examples[("sampled_tag", "collect")].to_pandas()
    ],
    dtype=object,
)

In [19]:
# Drop the final entry of every group's genre list (the genres of the item that
# was set aside as target_item).
grouped_examples[("genres", "collect")] = np.array(
    [
        genre_sequence[:-1]
        for genre_sequence in grouped_examples[("genres", "collect")].to_pandas()
    ],
    dtype=object,
)

In [20]:
# Drop the final entry of every group's timestamp list (the timestamp of the
# item that was set aside as target_item).
grouped_examples[("timestamp", "collect")] = np.array(
    [
        timestamp_sequence[:-1]
        for timestamp_sequence in grouped_examples[("timestamp", "collect")].to_pandas()
    ],
    dtype=object,
)

In [21]:
# Drop the final movieId of every group: it was already copied into
# target_item, so the remaining list holds only the preceding items.
grouped_examples[("movieId", "collect")] = np.array(
    [
        movie_sequence[:-1]
        for movie_sequence in grouped_examples[("movieId", "collect")].to_pandas()
    ],
    dtype=object,
)

In [22]:
# Each movieId list lost one element above, so lower the stored length by one.
grouped_examples[("movieId", "count")] = grouped_examples[("movieId", "count")].sub(1)

In [23]:
def flatten_genres(g):
    """Flatten a sequence of per-movie genre lists into one flat array.

    Args:
        g: Sequence whose elements are themselves sequences of genre ids
            (one inner sequence per movie in the group).

    Returns:
        A NumPy array holding every genre id in the order encountered, or an
        empty list when ``g`` is empty or holds only empty inner sequences.
    """
    # An empty inner list turns into a float64 array inside np.concatenate and
    # would upcast every integer genre id to float, so leave those out.
    populated = [genre_ids for genre_ids in g if len(genre_ids) > 0]
    if not populated:
        return []
    return np.concatenate(populated).ravel()

# Replace each group's list of per-movie genre lists with one flat sequence of
# genre ids, computed on the host through pandas.
grouped_examples[("genres", "collect")] = (
    grouped_examples[("genres", "collect")].to_pandas().map(flatten_genres)
)

In [24]:
# Preview the frame after the last element was split off and genres flattened.
grouped_examples.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,sampled_tag,genres,timestamp,movieId,movieId,target_item
Unnamed: 0_level_1,Unnamed: 1_level_1,collect,collect,collect,collect,count,Unnamed: 7_level_1
userId,day,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
1,4146,"[58406, 45148, 1688, 8795, 61752, 3545, 3500, ...","[6, 6, 3, 6, 17, 9, 16, 3, 9, 2, 3, 6, 10, 3, ...","[1147868053, 1147868097, 1147868414, 114786846...","[5952, 1653, 1250, 6539, 6377, 3448, 1088, 899...",52,7361
2,4071,"[52780, 65105, 38902, 64349, 1452, 57559, 3972...","[6, 6, 6, 6, 6, 9, 9, 7, 9, 15, 7, 9, 2, 6, 6,...","[1141415528, 1141415566, 1141415576, 114141558...","[5952, 497, 1374, 1653, 2640, 5445, 151, 236, ...",124,2150
3,7521,"[8568, 52357, 4576, 24974, 28670, 24770, 6610,...","[6, 9, 16, 9, 2, 6, 7, 9, 17, 3, 4, 5, 6, 10, ...","[1439472199, 1439472203, 1439472211, 143947221...","[356, 593, 1270, 1, 2571, 260, 318, 1196, 527,...",178,37729
3,7688,"[67417, 71080, 30253, 55981, 25864, 22485, 467...","[12, 12, 18, 7, 9, 18, 7, 9, 6, 9, 1, 1, 12, 2...","[1453904021, 1453904031, 1453904046, 145390404...","[1206, 1208, 44191, 32587, 40815, 36529, 45186...",9,4344
3,8045,"[6042, 50075, 32067, 6526, 57987, 30637, 56885...","[7, 9, 17, 18, 7, 9, 18, 7, 9, 9, 7, 9, 9, 18,...","[1484753654, 1484753766, 1484753808, 148475384...","[1089, 4011, 741, 778, 111, 214, 293, 1252, 33...",21,27773


In [25]:
# Inspect the two-level column labels before they are flattened to strings.
grouped_examples.columns

MultiIndex([('sampled_tag', 'collect'),
            (     'genres', 'collect'),
            (  'timestamp', 'collect'),
            (    'movieId', 'collect'),
            (    'movieId',   'count'),
            ('target_item',        '')],
           )

In [26]:
# Collapse the two-level column labels into plain strings: join the levels with
# "_" and strip the "_collect" suffix. ("target_item", "") therefore becomes
# "target_item_", which the following cell renames.
grouped_examples.columns = [
    "_".join(level_names).replace("_collect", "")
    for level_names in grouped_examples.columns
]

In [27]:
# Remove the trailing underscore left on "target_item_" by the label flattening.
# The column is already last, so renaming keeps the same column order.
grouped_examples = grouped_examples.rename(columns={"target_item_": "target_item"})
grouped_examples.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,sampled_tag,genres,timestamp,movieId,movieId_count,target_item
userId,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1,4146,"[58406, 45148, 1688, 8795, 61752, 3545, 3500, ...","[6, 6, 3, 6, 17, 9, 16, 3, 9, 2, 3, 6, 10, 3, ...","[1147868053, 1147868097, 1147868414, 114786846...","[5952, 1653, 1250, 6539, 6377, 3448, 1088, 899...",52,7361
2,4071,"[52780, 65105, 38902, 64349, 1452, 57559, 3972...","[6, 6, 6, 6, 6, 9, 9, 7, 9, 15, 7, 9, 2, 6, 6,...","[1141415528, 1141415566, 1141415576, 114141558...","[5952, 497, 1374, 1653, 2640, 5445, 151, 236, ...",124,2150
3,7521,"[8568, 52357, 4576, 24974, 28670, 24770, 6610,...","[6, 9, 16, 9, 2, 6, 7, 9, 17, 3, 4, 5, 6, 10, ...","[1439472199, 1439472203, 1439472211, 143947221...","[356, 593, 1270, 1, 2571, 260, 318, 1196, 527,...",178,37729
3,7688,"[67417, 71080, 30253, 55981, 25864, 22485, 467...","[12, 12, 18, 7, 9, 18, 7, 9, 6, 9, 1, 1, 12, 2...","[1453904021, 1453904031, 1453904046, 145390404...","[1206, 1208, 44191, 32587, 40815, 36529, 45186...",9,4344
3,8045,"[6042, 50075, 32067, 6526, 57987, 30637, 56885...","[7, 9, 17, 18, 7, 9, 18, 7, 9, 9, 7, 9, 9, 18,...","[1484753654, 1484753766, 1484753808, 148475384...","[1089, 4011, 741, 778, 111, 214, 293, 1252, 33...",21,27773


In [28]:
# Publish the flattened genre ids under the name "genre". Copying and then
# dropping (rather than renaming in place) moves the column to the end.
grouped_examples["genre"] = grouped_examples["genres"]
grouped_examples = grouped_examples.drop(columns=["genres"])
grouped_examples.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,sampled_tag,timestamp,movieId,movieId_count,target_item,genre
userId,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1,4146,"[58406, 45148, 1688, 8795, 61752, 3545, 3500, ...","[1147868053, 1147868097, 1147868414, 114786846...","[5952, 1653, 1250, 6539, 6377, 3448, 1088, 899...",52,7361,"[6, 6, 3, 6, 17, 9, 16, 3, 9, 2, 3, 6, 10, 3, ..."
2,4071,"[52780, 65105, 38902, 64349, 1452, 57559, 3972...","[1141415528, 1141415566, 1141415576, 114141558...","[5952, 497, 1374, 1653, 2640, 5445, 151, 236, ...",124,2150,"[6, 6, 6, 6, 6, 9, 9, 7, 9, 15, 7, 9, 2, 6, 6,..."
3,7521,"[8568, 52357, 4576, 24974, 28670, 24770, 6610,...","[1439472199, 1439472203, 1439472211, 143947221...","[356, 593, 1270, 1, 2571, 260, 318, 1196, 527,...",178,37729,"[6, 9, 16, 9, 2, 6, 7, 9, 17, 3, 4, 5, 6, 10, ..."
3,7688,"[67417, 71080, 30253, 55981, 25864, 22485, 467...","[1453904021, 1453904031, 1453904046, 145390404...","[1206, 1208, 44191, 32587, 40815, 36529, 45186...",9,4344,"[12, 12, 18, 7, 9, 18, 7, 9, 6, 9, 1, 1, 12, 2..."
3,8045,"[6042, 50075, 32067, 6526, 57987, 30637, 56885...","[1484753654, 1484753766, 1484753808, 148475384...","[1089, 4011, 741, 778, 111, 214, 293, 1252, 33...",21,27773,"[7, 9, 17, 18, 7, 9, 18, 7, 9, 9, 7, 9, 9, 18,..."


In [29]:
# TODO: We need another categorify here to handle the multi-hot features

In [30]:
# Persist the grouped examples next to the input files.
grouped_examples_path = os.path.join(INPUT_DATA_DIR, "grouped_examples.parquet")
grouped_examples.to_parquet(grouped_examples_path)