In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [2]:
import os

import cudf
import numpy as np
import rmm

import nvtabular as nvt

In [3]:
# Switch RMM to CUDA managed memory for the heavy cuDF reads/joins below
# (presumably so they can exceed device memory); switched back at the end of the notebook.
rmm.reinitialize(
    managed_memory=True,
)

In [4]:
# Root directory for all parquet inputs/outputs of this notebook. Override it with the
# INPUT_DATA_DIR environment variable. A leading "~" is expanded whether the value comes
# from the environment or from the "./data/" default (previously only the default, which
# contains no "~", was passed through expanduser).
INPUT_DATA_DIR = os.path.expanduser(os.environ.get("INPUT_DATA_DIR", "./data/"))

# Read the Retrieval Training Examples

In [5]:
# Load the retrieval training sessions and key them by the movie the user interacted with.
retrieval_training = cudf.read_parquet(
    os.path.join(INPUT_DATA_DIR, "retrieval_training.parquet")
)
# Every column gets a "user_" prefix except the target item, which becomes "movie_id"
# so it can act as the join key against the movie features below.
retrieval_training.columns = [
    "movie_id" if column == "target_item" else f"user_{column}"
    for column in retrieval_training.columns
]
retrieval_training = retrieval_training.reset_index().set_index("movie_id")
retrieval_training.head()

Unnamed: 0_level_0,user_id,day,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
7237,1,4146,"[40280, 34864, 60921, 55756, 53948, 968, 2649,...","[3, 10, 9, 17, 18, 3, 9, 19, 3, 4, 5, 6, 6, 9,...","[1147868053, 1147868097, 1147868414, 114786846...","[5841, 1592, 1218, 6259, 3353, 1062, 6589, 384...",38
2061,2,4071,"[70987, 21602, 55885, 217, 26361, 361, 38094, ...","[3, 10, 6, 16, 2, 3, 17, 18, 9, 17, 18, 2, 3, ...","[1141415528, 1141415566, 1141415576, 114141558...","[5841, 493, 1339, 1592, 2550, 150, 234, 4781, ...",112
24542,3,7521,"[63271, 34088, 581, 28225, 491, 34656, 41947, ...","[6, 9, 16, 19, 7, 12, 18, 3, 4, 5, 6, 10, 2, 1...","[1439472199, 1439472203, 1439472215, 143947222...","[352, 586, 1, 2481, 258, 315, 1167, 523, 12217...",247
4240,3,7688,"[69734, 68604, 71594, 56924, 32424, 31525, 668...","[7, 9, 17, 18, 2, 9, 19, 6, 16, 8, 20, 6, 10, ...","[1453904021, 1453904031, 1453904046, 145390404...","[1176, 1178, 10678, 9777, 11446, 11930, 10407,...",17
9335,3,8045,"[38644, 43776, 43440, 57466, 45557, 13968, 170...","[7, 15, 18, 2, 18, 6, 7, 18, 4, 17, 6, 7, 9, 7...","[1484753654, 1484753762, 1484753766, 148475380...","[1063, 29365, 3908, 726, 763, 110, 213, 29375,...",30


In [6]:
# Join movie features onto the positive examples from the retrieval training data

In [7]:
# Load per-movie features indexed by movie_id, dropping the index/datetime/created columns.
movie_features = (
    cudf.read_parquet(os.path.join(INPUT_DATA_DIR, "movie_features.parquet"))
    .reset_index()
    .set_index("movie_id")
    .drop(columns=["index", "datetime", "created"])
)
# Prefix the remaining columns with "movie_" to keep them distinct from the user_* columns.
movie_features.columns = [f"movie_{name}" for name in movie_features.columns]
movie_features.head()

Unnamed: 0_level_0,movie_tags_unique,movie_genres,movie_tags_nunique
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
27265,"[40345, 59079]",[9],2
27273,[28414],[8],1
27266,[32292],[9],1
27282,"[3365, 33048, 43053, 46467, 50807, 51310, 5739...","[9, 18]",10
27290,"[31336, 34155, 42699, 48919, 48957, 51656, 602...",[1],8


In [8]:
# Attach movie features to each positive (session, movie) row via the shared movie_id
# index, then re-key the result by session (user_id, day).
positive_interactions = (
    retrieval_training
    .join(movie_features, lsuffix="_user", rsuffix="_movie")
    .reset_index()
    .set_index(["user_id", "day"])
)
positive_interactions

Unnamed: 0_level_0,Unnamed: 1_level_0,movie_id,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
4361,7803,12327,"[51520, 43758, 70349, 22247, 43755, 18446, 0, ...","[6, 2, 7, 9, 1, 6, 9, 16, 19, 8, 9, 5, 9, 6, 1...","[1463814550, 1463814624, 1463814628, 146381465...","[24552, 35337, 28744, 352, 14926, 12320, 14018...",20,"[5187, 13991, 14352, 16982, 19408, 21201, 2234...","[6, 16]",26
4371,6850,1641,"[61112, 39183, 71328, 52690, 34041, 16962, 281...","[9, 20, 9, 9, 16, 2, 3, 9, 10, 9, 12, 12, 17, ...","[1381461702, 1381461713, 1381461744, 138146177...","[11828, 15228, 18981, 14619, 14500, 17042, 187...",43,"[256, 2817, 3281, 3284, 3965, 5048, 6191, 9537...","[9, 16]",119
4366,628,1073,"[34811, 35454, 42323, 56711, 48406, 33143, 455...","[18, 7, 12, 18, 6, 9, 16, 19, 7, 9, 2, 3, 17, ...","[843927996, 843927996, 843928045, 843928045, 8...","[453, 586, 352, 315, 476, 251, 286, 363, 523, ...",13,"[2816, 4903, 5274, 6223, 7971, 8092, 9853, 276...","[5, 9, 10]",35
4375,4888,12276,"[36554, 37239, 41828, 71030, 636, 62981, 118, ...","[2, 9, 18, 12, 17, 18, 5, 9, 9, 2, 9, 17, 2, 1...","[1211953937, 1211954021, 1211954028, 121195403...","[7300, 948, 527, 1067, 2571, 822, 11116, 10803...",71,"[2364, 10398, 39197, 41467, 51205, 53886]","[8, 14]",6
4362,3725,1804,"[50798, 53842, 69382, 1076, 60344, 42604, 1073...","[9, 17, 9, 15, 16, 18, 9, 9, 14, 16, 9, 16, 6,...","[1111493534, 1111493678, 1111493688, 111149369...","[1528, 1541, 3068, 1062, 2767, 590, 586, 4204,...",9,"[2030, 6659, 16810, 22746, 24933, 26075, 30064...",[18],28
...,...,...,...,...,...,...,...,...,...,...
161748,1774,586,"[46366, 26485, 66429, 59237, 29634, 28564, 714...","[9, 9, 2, 3, 6, 10, 16, 9, 9, 6, 6, 16, 9, 16,...","[942932313, 942932528, 942932528, 942932528, 9...","[2223, 1195, 1168, 1873, 1869, 2827, 1274, 193...",40,"[269, 439, 457, 927, 965, 968, 969, 1135, 2302...","[7, 12, 18]",287
161731,879,1202,"[35723, 46964, 22448, 28139, 28007, 34424, 185...","[3, 4, 5, 6, 10, 9, 16, 7, 9, 6, 7, 9, 18, 6, ...","[865619550, 865619550, 865619588, 865619589, 8...","[1, 17, 36, 601, 52, 768, 26, 785, 29, 85, 714...",12,"[685, 2817, 3219, 3538, 3545, 6224, 10021, 100...","[2, 9, 19]",73
161743,8378,619,"[68198, 36596, 6022, 70820, 54217, 58063, 5354...","[9, 18, 6, 9, 16, 6, 9, 16, 7, 9, 12, 15, 18, ...","[1513470181, 1513470199, 1513470217, 151347024...","[4132, 2779, 214, 21653, 2836, 22, 20456, 4557...",42,"[4448, 6003, 7199, 8084, 8085, 10561, 15746, 1...","[7, 9, 15, 18]",55
161738,2033,3517,"[815, 22765, 34984, 35510, 26425, 47665, 64248...","[9, 19, 9, 16, 4, 5, 6, 9, 18, 2, 3, 9, 7, 9, ...","[965328135, 965328135, 965328241, 965328241, 9...","[523, 1554, 3651, 3655, 3480, 3500, 3525, 3653]",8,"[685, 7474, 29463, 29481, 30197, 34392, 36173,...","[3, 4, 5]",21


In [9]:
# Next: load the negative ratings from earlier and join one sampled negative per session onto the positive examples (target_item was already renamed to movie_id above)

In [10]:
# Load the negative ratings keyed by session, dropping the rating/interaction/timestamp columns.
negative_ratings = (
    cudf.read_parquet(os.path.join(INPUT_DATA_DIR, "negative_ratings.parquet"))
    .set_index(["user_id", "day"])
    .drop(columns=["rating", "interaction", "timestamp"])
)
negative_ratings.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,movie_id
user_id,day,Unnamed: 2_level_1
1,4146,303
1,4146,879
1,4146,1148
1,4146,1187
1,4146,1228


In [11]:
# Group by user id and day to form lists, then sample one negative per session

In [13]:
# One row per (user_id, day) session holding the list of negative movie ids and its length.
negative_sessions = (
    negative_ratings
    .reset_index()
    .groupby(["user_id", "day"])
    .agg({"movie_id": ["collect", "count"]})
)
negative_sessions.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,movie_id,movie_id
Unnamed: 0_level_1,Unnamed: 1_level_1,collect,count
user_id,day,Unnamed: 2_level_2,Unnamed: 3_level_2
1,4146,"[303, 879, 1148, 1187, 1228, 1923, 1924, 1980,...",31
2,4071,"[1, 62, 259, 264, 376, 476, 520, 548, 581, 643...",71
3,7521,"[172, 438, 476, 765, 1100, 1169, 1238, 1287, 1...",154
3,7688,"[10054, 10168, 10450, 10784, 11114, 12353, 124...",8
3,8045,"[3236, 7847, 12544, 12615, 13813, 13937, 14034...",33


In [14]:
# Pick one random position in each session's list of negatives (-1 for an empty list).
# A seeded Generator makes the sampled negatives reproducible across runs, and the draw is
# vectorized over all sessions instead of looping in Python.
NEGATIVE_SAMPLING_SEED = 42
negative_counts = negative_sessions[("movie_id", "count")].to_pandas().to_numpy()
sampled_indices = np.where(
    negative_counts > 0,
    # high is exclusive; clamp to 1 so empty sessions don't make integers() raise
    # (their draw is discarded by the np.where mask anyway).
    np.random.default_rng(NEGATIVE_SAMPLING_SEED).integers(0, np.maximum(negative_counts, 1)),
    -1,
).astype(np.int32)

In [15]:
# Look up the sampled position in each session's list of negatives; a position of -1
# (no negatives) maps to movie_id 0.
sampled_items = np.array(
    [
        0 if position < 0 else movie_list[position]
        for position, movie_list in zip(
            sampled_indices, negative_sessions[("movie_id", "collect")].to_pandas()
        )
    ],
    dtype=np.int32,
)

In [16]:
# Flatten the ("movie_id", "collect") / ("movie_id", "count") columns from the groupby
# into plain names, then attach the sampled negative for each session as "movie_id".
# NOTE: this renames negative_sessions in place, so the sampling cells above (which read
# the ("movie_id", ...) columns) only work again after re-running the groupby cell.
negative_sessions.columns = ["movie_ids", "movie_id_count"]
negative_sessions["movie_id"] = sampled_items
negative_sessions.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,movie_ids,movie_id_count,movie_id
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,4146,"[303, 879, 1148, 1187, 1228, 1923, 1924, 1980,...",31,8016
2,4071,"[1, 62, 259, 264, 376, 476, 520, 548, 581, 643...",71,4431
3,7521,"[172, 438, 476, 765, 1100, 1169, 1238, 1287, 1...",154,4284
3,7688,"[10054, 10168, 10450, 10784, 11114, 12353, 124...",8,10784
3,8045,"[3236, 7847, 12544, 12615, 13813, 13937, 14034...",33,19837


In [17]:
# Keep only the sampled negative movie per session.
negative_targets = negative_sessions[["movie_id"]]
negative_targets.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,movie_id
user_id,day,Unnamed: 2_level_1
1,4146,8016
2,4071,4431
3,7521,4284
3,7688,10784
3,8045,19837


In [18]:
# Left join on (user_id, day): every positive row is kept, and sessions without a sampled
# negative get a null movie_id_neg. The suffixes disambiguate the two movie_id columns.
both_targets = positive_interactions.join(
    negative_targets,
    how="left",
    lsuffix="_pos",
    rsuffix="_neg",
)
both_targets

Unnamed: 0_level_0,Unnamed: 1_level_0,movie_id_pos,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,movie_id_neg
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
6167,4054,1234,"[526, 68122, 32890, 35728, 16493, 50675, 42700...","[6, 9, 16, 2, 7, 9, 18, 7, 9, 6, 16, 9, 17, 3,...","[1139957672, 1139957684, 1139957687, 113995773...","[11, 2868, 1183, 1274, 1528, 2896, 5311, 3022,...",43,"[447, 698, 962, 965, 2212, 2817, 2827, 2832, 2...","[9, 20]",91,1259
6170,875,588,"[8366, 43579, 38914, 5693, 41682, 12393, 4448,...","[3, 4, 5, 6, 10, 5, 6, 10, 14, 2, 3, 18, 6, 2,...","[865250838, 865250882, 865250883, 865250883, 8...","[1, 1048, 490, 5, 372, 773, 1490, 1358, 86, 83...",82,"[361, 547, 640, 685, 1434, 2175, 2177, 2817, 3...","[4, 5, 10, 14, 16, 13]",82,747
6169,5792,6652,"[22562, 37450, 67052, 30565, 59560, 62731, 609...","[6, 10, 6, 12, 14, 17, 4, 5, 10, 14, 5, 6, 3, ...","[1290117144, 1290117158, 1290117184, 129011718...","[1246, 2566, 1250, 2713, 2072, 3809, 2655, 199...",9,"[3545, 6022, 6224, 6540, 6541, 6606, 6755, 124...","[10, 12, 17, 18]",71,883
6172,9019,20580,"[2817, 39421, 22946, 46568, 32052, 50054, 2336...","[7, 9, 2, 7, 9, 18, 3, 4, 5, 6, 10, 16, 2, 3, ...","[1568910192, 1568910198, 1568910227, 156891023...","[315, 2868, 4202, 3480, 1940, 4888, 6752, 7237...",24,"[28225, 30747, 66047]","[6, 8, 9]",3,11275
6174,2176,2055,"[33834, 32342, 3219, 61085, 9126, 57269, 2817,...","[9, 16, 4, 5, 6, 2, 3, 9, 6, 12, 9, 18, 2, 9, ...","[977644832, 977644971, 977644971, 977645013, 9...","[1656, 3651, 3480, 3685, 3315, 3653, 2713, 315...",41,"[431, 703, 709, 721, 2817, 4033, 13340, 13406,...","[6, 16]",50,3416
...,...,...,...,...,...,...,...,...,...,...,...
162533,6248,1259,"[8477, 0]","[6, 9, 16, 4, 5]","[1329514139, 1329514173]","[16699, 16208]",2,"[389, 685, 1136, 1266, 1696, 1964, 2701, 2817,...","[2, 3]",102,16830
162533,5875,15381,[68186],[15],[1297289876],[15464],1,"[17023, 31356, 33288, 33627, 34209, 34429, 435...","[6, 9]",12,
162530,2319,537,"[8770, 7556, 36137, 40642, 56019, 22246, 35740...","[7, 9, 2, 3, 20, 2, 3, 6, 10, 16, 2, 9, 19, 2,...","[990018455, 990018600, 990018600, 990018600, 9...","[841, 1172, 1168, 109, 1208, 1180, 3430, 719, ...",9,"[270, 431, 451, 496, 685, 721, 730, 969, 1964,...","[2, 17, 18]",231,
162521,6134,16083,[29603],"[3, 4, 5, 6, 13]",[1319648967],[11209],1,"[48635, 69170, 71167, 72716]","[6, 9]",4,


In [19]:
# Keep only sessions that also have a sampled negative, so both classes cover the same sessions.
positive_examples = both_targets[both_targets["movie_id_neg"].notna()]
# Move the positive movie id into a trailing "movie_id" column and drop both target columns.
positive_examples["movie_id"] = positive_examples["movie_id_pos"]
positive_examples = positive_examples.drop(columns=["movie_id_pos", "movie_id_neg"])
positive_examples["label"] = 1  # positive class
positive_examples

Unnamed: 0_level_0,Unnamed: 1_level_0,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,movie_id,label
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
6167,4054,"[526, 68122, 32890, 35728, 16493, 50675, 42700...","[6, 9, 16, 2, 7, 9, 18, 7, 9, 6, 16, 9, 17, 3,...","[1139957672, 1139957684, 1139957687, 113995773...","[11, 2868, 1183, 1274, 1528, 2896, 5311, 3022,...",43,"[447, 698, 962, 965, 2212, 2817, 2827, 2832, 2...","[9, 20]",91,1234,1
6170,875,"[8366, 43579, 38914, 5693, 41682, 12393, 4448,...","[3, 4, 5, 6, 10, 5, 6, 10, 14, 2, 3, 18, 6, 2,...","[865250838, 865250882, 865250883, 865250883, 8...","[1, 1048, 490, 5, 372, 773, 1490, 1358, 86, 83...",82,"[361, 547, 640, 685, 1434, 2175, 2177, 2817, 3...","[4, 5, 10, 14, 16, 13]",82,588,1
6169,5792,"[22562, 37450, 67052, 30565, 59560, 62731, 609...","[6, 10, 6, 12, 14, 17, 4, 5, 10, 14, 5, 6, 3, ...","[1290117144, 1290117158, 1290117184, 129011718...","[1246, 2566, 1250, 2713, 2072, 3809, 2655, 199...",9,"[3545, 6022, 6224, 6540, 6541, 6606, 6755, 124...","[10, 12, 17, 18]",71,6652,1
6172,9019,"[2817, 39421, 22946, 46568, 32052, 50054, 2336...","[7, 9, 2, 7, 9, 18, 3, 4, 5, 6, 10, 16, 2, 3, ...","[1568910192, 1568910198, 1568910227, 156891023...","[315, 2868, 4202, 3480, 1940, 4888, 6752, 7237...",24,"[28225, 30747, 66047]","[6, 8, 9]",3,20580,1
6174,2176,"[33834, 32342, 3219, 61085, 9126, 57269, 2817,...","[9, 16, 4, 5, 6, 2, 3, 9, 6, 12, 9, 18, 2, 9, ...","[977644832, 977644971, 977644971, 977645013, 9...","[1656, 3651, 3480, 3685, 3315, 3653, 2713, 315...",41,"[431, 703, 709, 721, 2817, 4033, 13340, 13406,...","[6, 16]",50,2055,1
...,...,...,...,...,...,...,...,...,...,...,...
162533,5685,"[31871, 61139, 42965, 45010, 63891, 72645, 201...","[2, 9, 19, 9, 18, 9, 3, 9, 19, 7, 9, 9, 16, 2,...","[1280832291, 1280832341, 1280832346, 128083264...","[2853, 477, 1157, 1561, 2239, 892, 5905, 735, ...",62,"[233, 591, 1624, 1626, 3607, 6022, 6224, 7021,...","[9, 16]",83,5118,1
162533,5879,"[34063, 29481]","[8, 3, 4, 5, 6, 10]","[1297630922, 1297631112]","[15519, 4781]",2,"[4780, 5253, 46621, 56296, 58326, 60478]",[9],6,10091,1
162533,5692,"[59896, 56924, 32286, 67014, 17779, 22036, 354...","[9, 19, 8, 6, 9, 16, 19, 2, 3, 9, 10, 4, 5, 10...","[1281405901, 1281405922, 1281405928, 128140595...","[14804, 9777, 352, 7029, 588, 10001, 3893, 583...",107,"[2817, 4338, 4448, 13252, 18319, 28833, 29481,...","[3, 4, 5, 6, 10, 17]",38,662,1
162533,6248,"[8477, 0]","[6, 9, 16, 4, 5]","[1329514139, 1329514173]","[16699, 16208]",2,"[389, 685, 1136, 1266, 1696, 1964, 2701, 2817,...","[2, 3]",102,1259,1


In [20]:
# Same session filter as for the positives. The sampled negative becomes movie_id, and the
# positive movie's features are dropped so the negative movie's can be joined in next.
negative_interactions = both_targets[both_targets["movie_id_neg"].notna()]
negative_interactions["movie_id"] = negative_interactions["movie_id_neg"]
negative_interactions = (
    negative_interactions
    .drop(
        columns=[
            "movie_id_pos",
            "movie_id_neg",
            "movie_tags_unique",
            "movie_tags_nunique",
            "movie_genres",
        ]
    )
    .reset_index()
    .set_index("movie_id")
)
negative_interactions

Unnamed: 0_level_0,user_id,day,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count
movie_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1259,6167,4054,"[526, 68122, 32890, 35728, 16493, 50675, 42700...","[6, 9, 16, 2, 7, 9, 18, 7, 9, 6, 16, 9, 17, 3,...","[1139957672, 1139957684, 1139957687, 113995773...","[11, 2868, 1183, 1274, 1528, 2896, 5311, 3022,...",43
747,6170,875,"[8366, 43579, 38914, 5693, 41682, 12393, 4448,...","[3, 4, 5, 6, 10, 5, 6, 10, 14, 2, 3, 18, 6, 2,...","[865250838, 865250882, 865250883, 865250883, 8...","[1, 1048, 490, 5, 372, 773, 1490, 1358, 86, 83...",82
883,6169,5792,"[22562, 37450, 67052, 30565, 59560, 62731, 609...","[6, 10, 6, 12, 14, 17, 4, 5, 10, 14, 5, 6, 3, ...","[1290117144, 1290117158, 1290117184, 129011718...","[1246, 2566, 1250, 2713, 2072, 3809, 2655, 199...",9
11275,6172,9019,"[2817, 39421, 22946, 46568, 32052, 50054, 2336...","[7, 9, 2, 7, 9, 18, 3, 4, 5, 6, 10, 16, 2, 3, ...","[1568910192, 1568910198, 1568910227, 156891023...","[315, 2868, 4202, 3480, 1940, 4888, 6752, 7237...",24
3416,6174,2176,"[33834, 32342, 3219, 61085, 9126, 57269, 2817,...","[9, 16, 4, 5, 6, 2, 3, 9, 6, 12, 9, 18, 2, 9, ...","[977644832, 977644971, 977644971, 977645013, 9...","[1656, 3651, 3480, 3685, 3315, 3653, 2713, 315...",41
...,...,...,...,...,...,...,...
1532,162533,5685,"[31871, 61139, 42965, 45010, 63891, 72645, 201...","[2, 9, 19, 9, 18, 9, 3, 9, 19, 7, 9, 9, 16, 2,...","[1280832291, 1280832341, 1280832346, 128083264...","[2853, 477, 1157, 1561, 2239, 892, 5905, 735, ...",62
15523,162533,5879,"[34063, 29481]","[8, 3, 4, 5, 6, 10]","[1297630922, 1297631112]","[15519, 4781]",2
1918,162533,5692,"[59896, 56924, 32286, 67014, 17779, 22036, 354...","[9, 19, 8, 6, 9, 16, 19, 2, 3, 9, 10, 4, 5, 10...","[1281405901, 1281405922, 1281405928, 128140595...","[14804, 9777, 352, 7029, 588, 10001, 3893, 583...",107
16830,162533,6248,"[8477, 0]","[6, 9, 16, 4, 5]","[1329514139, 1329514173]","[16699, 16208]",2


In [21]:
# Attach the negative movie's features, re-key by session, and label as the negative class.
negative_examples = (
    negative_interactions
    .join(movie_features, how="left", lsuffix="", rsuffix="_movie")
    .reset_index()
    .set_index(["user_id", "day"])
)
negative_examples["label"] = 0
negative_examples.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,movie_id,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,label
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
1499,3357,338,"[55756, 37891, 71389, 44760, 19053, 51376, 182...","[2, 9, 19, 6, 7, 16, 2, 6, 19, 9, 9, 17, 18, 6...","[1079735246, 1079735256, 1079735260, 107973530...","[1202, 1453, 462, 1324, 1592, 467, 2305, 1615,...",46,"[439, 752, 857, 2696, 2697, 2816, 3647, 4541, ...",[6],64,0
1503,7831,11263,[0],[9],[1466262866],[31326],1,"[373, 2325, 2835, 5078, 6204, 8084, 13338, 177...","[9, 16]",48,0
1500,8856,18593,"[52321, 30146, 2835, 51881, 62140, 34324, 2822...","[9, 19, 6, 9, 16, 19, 9, 16, 6, 16, 9, 9, 16, ...","[1554779513, 1554779548, 1554779560, 155478015...","[523, 352, 17, 22781, 1214, 892, 20580, 900, 6...",16,"[2816, 6578]",[8],2,0
3760,7390,3756,"[4159, 22957, 52589, 66942, 35459, 4541, 61971...","[2, 12, 17, 7, 9, 18, 9, 18, 2, 9, 19, 2, 3, 7...","[1428178639, 1428178649, 1428178662, 142817867...","[6380, 5358, 4132, 1178, 5905, 1176, 13941, 28...",177,"[6222, 6224, 13509, 13974, 16922, 32589, 32875...",[6],22,0
1503,7637,14472,[0],[1],[1449523965],[32629],1,"[1801, 2650, 6223, 12710, 13866, 14691, 15125,...","[9, 18]",31,0


In [22]:
# Stack positives (label 1) and negatives (label 0). Both frames must carry exactly the
# same columns: concat aligns by name and would otherwise silently null-fill the
# columns missing from one side.
assert set(positive_examples.columns) == set(negative_examples.columns), (
    "positive and negative examples have different columns"
)
training_examples = cudf.concat([positive_examples, negative_examples])
training_examples.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,movie_id,label
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
6167,4054,"[526, 68122, 32890, 35728, 16493, 50675, 42700...","[6, 9, 16, 2, 7, 9, 18, 7, 9, 6, 16, 9, 17, 3,...","[1139957672, 1139957684, 1139957687, 113995773...","[11, 2868, 1183, 1274, 1528, 2896, 5311, 3022,...",43,"[447, 698, 962, 965, 2212, 2817, 2827, 2832, 2...","[9, 20]",91,1234,1
6170,875,"[8366, 43579, 38914, 5693, 41682, 12393, 4448,...","[3, 4, 5, 6, 10, 5, 6, 10, 14, 2, 3, 18, 6, 2,...","[865250838, 865250882, 865250883, 865250883, 8...","[1, 1048, 490, 5, 372, 773, 1490, 1358, 86, 83...",82,"[361, 547, 640, 685, 1434, 2175, 2177, 2817, 3...","[4, 5, 10, 14, 16, 13]",82,588,1
6169,5792,"[22562, 37450, 67052, 30565, 59560, 62731, 609...","[6, 10, 6, 12, 14, 17, 4, 5, 10, 14, 5, 6, 3, ...","[1290117144, 1290117158, 1290117184, 129011718...","[1246, 2566, 1250, 2713, 2072, 3809, 2655, 199...",9,"[3545, 6022, 6224, 6540, 6541, 6606, 6755, 124...","[10, 12, 17, 18]",71,6652,1
6172,9019,"[2817, 39421, 22946, 46568, 32052, 50054, 2336...","[7, 9, 2, 7, 9, 18, 3, 4, 5, 6, 10, 16, 2, 3, ...","[1568910192, 1568910198, 1568910227, 156891023...","[315, 2868, 4202, 3480, 1940, 4888, 6752, 7237...",24,"[28225, 30747, 66047]","[6, 8, 9]",3,20580,1
6174,2176,"[33834, 32342, 3219, 61085, 9126, 57269, 2817,...","[9, 16, 4, 5, 6, 2, 3, 9, 6, 12, 9, 18, 2, 9, ...","[977644832, 977644971, 977644971, 977645013, 9...","[1656, 3651, 3480, 3685, 3315, 3653, 2713, 315...",41,"[431, 703, 709, 721, 2817, 4033, 13340, 13406,...","[6, 16]",50,2055,1


In [23]:
# concat placed all positives before all negatives, so shuffle the rows to interleave the
# classes. The permutation is seeded so the row order is reproducible across runs.
SHUFFLE_SEED = 42
shuffle_order = np.random.default_rng(SHUFFLE_SEED).permutation(len(training_examples))
shuffled_examples = training_examples.reset_index().iloc[shuffle_order]
shuffled_examples

Unnamed: 0,user_id,day,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,movie_id,label
55102,25690,2383,"[40621, 64668]","[9, 6, 9]","[995542104, 995542222]","[3314, 3471]",2,"[1423, 2817, 3607, 4050, 5274, 6224, 8530, 123...","[9, 15, 18]",98,2621,1
673186,85254,2855,"[40952, 816, 57559, 48023, 52666, 18291, 2816,...","[6, 16, 4, 9, 10, 9, 2, 3, 6, 9, 16, 19, 6, 19...","[1036287980, 1036288019, 1036288083, 103628815...","[5192, 4768, 5723, 5685, 5688, 5690, 5691, 5634]",8,"[457, 472, 3300, 3428, 4494, 11716, 13596, 184...","[7, 9]",38,4274,0
116216,45463,4900,[60889],"[9, 18]",[1212986201],[3849],1,"[19263, 22052, 23714, 34063, 35238, 37615, 391...","[8, 14]",21,12353,1
604409,55998,7937,[45296],"[7, 9]",[1475418791],[315],1,"[256, 2817, 3281, 3284, 3965, 5048, 6191, 9537...","[9, 16]",119,1641,0
11111,464,7627,"[60020, 31525, 968, 2435, 57898, 52589, 36831,...","[3, 10, 3, 10, 2, 3, 17, 9, 11, 2, 3, 3, 6, 9,...","[1448668575, 1448668634, 1448668636, 144866863...","[4888, 5841, 258, 13358, 1169, 16830, 11446, 1...",15,"[379, 389, 685, 721, 1136, 1266, 1916, 1964, 2...","[2, 3, 10]",92,2026,1
...,...,...,...,...,...,...,...,...,...,...,...,...
172393,62406,7548,"[41523, 439, 45113, 39008, 10050, 50795, 45036...","[6, 9, 16, 19, 7, 12, 18, 2, 3, 17, 18, 3, 4, ...","[1441823132, 1441823135, 1441823141, 144182314...","[352, 586, 476, 1, 2481, 1238, 315, 841, 523, ...",60,"[352, 1497, 1747, 1858, 2254, 2338, 2816, 2817...","[2, 3, 20]",115,1172,1
847447,151009,4662,"[50982, 47665]","[9, 10, 15, 18, 12]","[1192473784, 1192473789]","[11475, 11546]",2,"[7493, 8784, 26721, 31358, 32424, 34442, 36831...",[20],12,11446,0
744455,107494,5103,"[42086, 20525, 61707, 38580, 55852, 46124, 661...","[6, 9, 6, 7, 18, 6, 3, 7, 9, 18, 9, 15, 17, 18...","[1230508800, 1230508901, 1230508907, 123050891...","[11925, 3908, 11116, 11486, 4773, 4868, 11123,...",29,"[3069, 3108, 3219, 5693, 6223, 10080, 10226, 1...","[2, 6, 17, 20]",37,2610,0
310808,115952,8904,"[29481, 36137]","[2, 3, 4, 5, 6, 2, 6]","[1558940428, 1558940813]","[8247, 53633]",2,"[32292, 42833]",[9],2,24549,1


In [24]:
# Persist the shuffled ranking training set without its index. The index holds each row's
# position before shuffling, and since positives were concatenated before negatives, it
# would encode the label. It is not a RangeIndex, so it would otherwise be written as a column.
shuffled_examples.to_parquet(
    os.path.join(INPUT_DATA_DIR, "ranking_training.parquet"), index=False
)

In [25]:
# Switch RMM back to non-managed device memory now that the heavy processing is done.
rmm.reinitialize(
    managed_memory=False,
)