In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [2]:
import os

import cudf
import numpy as np
import rmm

import nvtabular as nvt

In [3]:
# Reinitialize RMM with managed_memory=True for the work below; the final cell of
# this notebook reverts it with managed_memory=False.
# NOTE(review): presumably managed memory is enabled so the large joins/concats
# below can exceed device memory — confirm.
rmm.reinitialize(managed_memory=True)

In [4]:
# Directory the input parquet files (retrieval_training, movie_features,
# negative_ratings) are read from and ranking_training.parquet is written to.
# Override with the INPUT_DATA_DIR environment variable; "~" is expanded in
# either the environment value or the default.
INPUT_DATA_DIR = os.path.expanduser(
    os.environ.get("INPUT_DATA_DIR", "./data/")
)

# Read the Retrieval Training Examples

In [5]:
# Load the retrieval training sessions. Every column except target_item gets a
# "user_" prefix so user-side features stay distinct from the movie_* features
# joined later; reset_index() turns the index stored in the file back into
# regular columns, and target_item becomes the join key.
retrieval_path = os.path.join(INPUT_DATA_DIR, "retrieval_training.parquet")
retrieval_training = cudf.read_parquet(retrieval_path)
user_prefixed = {
    col: col if col == "target_item" else f"user_{col}"
    for col in retrieval_training.columns
}
retrieval_training = (
    retrieval_training
    .rename(columns=user_prefixed)
    .reset_index()
    .set_index("target_item")
)
retrieval_training.head()

Unnamed: 0_level_0,user_id,day,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count
target_item,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
7237,1,4146,"[40280, 34864, 60921, 55756, 53948, 968, 2649,...","[3, 10, 9, 17, 18, 3, 9, 19, 3, 4, 5, 6, 6, 9,...","[1147868053, 1147868097, 1147868414, 114786846...","[5841, 1592, 1218, 6259, 3353, 1062, 6589, 384...",38
2061,2,4071,"[70987, 21602, 55885, 217, 26361, 361, 38094, ...","[3, 10, 6, 16, 2, 3, 17, 18, 9, 17, 18, 2, 3, ...","[1141415528, 1141415566, 1141415576, 114141558...","[5841, 493, 1339, 1592, 2550, 150, 234, 4781, ...",112
24542,3,7521,"[63271, 34088, 581, 28225, 491, 34656, 41947, ...","[6, 9, 16, 19, 7, 12, 18, 3, 4, 5, 6, 10, 2, 1...","[1439472199, 1439472203, 1439472215, 143947222...","[352, 586, 1, 2481, 258, 315, 1167, 523, 12217...",247
4240,3,7688,"[69734, 68604, 71594, 56924, 32424, 31525, 668...","[7, 9, 17, 18, 2, 9, 19, 6, 16, 8, 20, 6, 10, ...","[1453904021, 1453904031, 1453904046, 145390404...","[1176, 1178, 10678, 9777, 11446, 11930, 10407,...",17
9335,3,8045,"[38644, 43776, 43440, 57466, 45557, 13968, 170...","[7, 15, 18, 2, 18, 6, 7, 18, 4, 17, 6, 7, 9, 7...","[1484753654, 1484753762, 1484753766, 148475380...","[1063, 29365, 3908, 726, 763, 110, 213, 29375,...",30


In [6]:
# Join movie features onto the positive examples from the retrieval training data

In [7]:
# Load per-movie features, key them by target_item (a copy of movie_id) so they
# can be joined onto interactions, discard bookkeeping columns, and prefix the
# remaining feature columns with "movie_".
movie_features_path = os.path.join(INPUT_DATA_DIR, "movie_features.parquet")
movie_features = cudf.read_parquet(movie_features_path)
movie_features["target_item"] = movie_features["movie_id"]
movie_features = (
    movie_features
    .reset_index()
    .set_index("target_item")
    .drop(columns=["movie_id", "index", "datetime", "created"])
)
movie_features.columns = ["movie_" + str(col) for col in movie_features.columns]
movie_features.head()

Unnamed: 0_level_0,movie_tags_unique,movie_genres,movie_tags_nunique
target_item,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
27265,"[40345, 59079]",[9],2
27273,[28414],[8],1
27266,[32292],[9],1
27282,"[3365, 33048, 43053, 46467, 50807, 51310, 5739...","[9, 18]",10
27290,"[31336, 34155, 42699, 48919, 48957, 51656, 602...",[1],8


In [8]:
# Attach movie features to each positive (session, target_item) pair via the
# shared target_item index, then re-key the rows by (user_id, day).
positive_interactions = retrieval_training.join(
    movie_features,
    lsuffix="_user",
    rsuffix="_movie",
)
positive_interactions = positive_interactions.reset_index().set_index(["user_id", "day"])
positive_interactions

Unnamed: 0_level_0,Unnamed: 1_level_0,target_item,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
7981,6136,2868,"[547, 28737, 57284, 35927, 13432, 37040, 42445...","[4, 5, 6, 14, 6, 9, 16, 3, 9, 8, 9, 17, 18, 2,...","[1319770808, 1319803451, 1319806247, 131980631...","[11081, 16423, 1227, 11360, 12425, 15525, 1335...",16,"[613, 752, 2817, 3046, 3219, 4004, 4541, 4579,...","[2, 7, 9, 18]",258
7985,2305,548,"[66886, 6858, 49266]","[6, 12, 15, 18, 2, 7, 18, 6, 9]","[988772309, 988772583, 988773288]","[1370, 184, 512]",3,"[2463, 3607, 3661, 3696, 4448, 6324, 6329, 751...","[2, 9, 20]",57
7982,4118,1048,"[32698, 56636, 60016, 11298, 3101, 39888, 2550...","[4, 6, 10, 14, 16, 5, 6, 2, 7, 9, 5, 9, 10, 16...","[1145478723, 1145478750, 1145478818, 114547885...","[10249, 10760, 2105, 933, 7743, 3912, 1237, 10...",14,"[640, 687, 2254, 2817, 3545, 5449, 8371, 8912,...","[5, 6, 10, 14]",88
7990,7542,25989,"[18859, 23513, 40303, 32042, 54399, 6949, 5258...","[7, 9, 9, 19, 9, 16, 6, 9, 16, 19, 6, 8, 2, 7,...","[1441270874, 1441270877, 1441270881, 144127088...","[315, 523, 1641, 352, 29271, 21717, 21106, 287...",9,"[34450, 72650]",[20],2
7981,6140,11884,"[46124, 16299, 61562, 44182]","[3, 4, 10, 9, 9, 16, 9]","[1320189757, 1320189774, 1320189776, 1320189780]","[5509, 1174, 12924, 15524]",4,"[851, 1617, 12688, 28769, 47665]","[6, 9, 16]",5
...,...,...,...,...,...,...,...,...,...,...
162132,4891,2439,"[68408, 40789, 30064, 37077, 64974, 62738, 596...","[6, 14, 5, 6, 10, 14, 9, 6, 7, 9, 15, 18, 3, 4...","[1212259936, 1212260239, 1212260557, 121226076...","[9951, 1048, 12320, 11730, 4202, 899, 546, 7262]",8,"[39, 968, 969, 2817, 3219, 4972, 5481, 6223, 7...","[2, 9, 17]",84
162119,2403,251,"[45068, 20553, 3951, 45787, 59966, 66886, 6694...","[3, 4, 5, 6, 10, 16, 2, 3, 4, 9, 10, 9, 18, 19...","[997306638, 997306834, 997306943, 997306974, 9...","[4202, 2909, 3930, 3893, 4263, 2071, 3931, 700...",38,"[664, 675, 2245, 2370, 4004, 4159, 4448, 4541,...","[9, 12]",80
162128,6062,13857,"[51881, 66853, 65103, 10050, 41047, 43060]","[9, 9, 2, 18, 2, 6, 10, 13, 9, 10, 16, 7, 9, 1...","[1313385434, 1313385530, 1313385534, 131338554...","[15752, 11586, 14972, 10944, 7023, 11620]",6,"[2816, 4792, 35716, 40173]","[9, 16, 19]",4
162125,8645,7237,"[64520, 59142, 32021, 2816, 31783, 45042, 6484...","[9, 12, 12, 6, 16, 9, 12, 18, 3, 10, 3, 4, 5, ...","[1536609214, 1536609436, 1536609507, 153660974...","[46060, 1226, 4868, 6851, 4888, 4202, 6752, 47...",46,"[49, 2859, 3483, 4541, 4947, 6191, 6224, 6542,...","[9, 16, 17]",193


In [9]:
# Next: load the negative examples from earlier, join one sampled negative per session onto the positives, and drop the positive target_item when building the negative rows

In [10]:
# Load the negative interactions, keep only the item id (as negative_item), and
# key them by (user_id, day) to line up with the positive sessions.
negative_ratings_path = os.path.join(INPUT_DATA_DIR, "negative_ratings.parquet")
negative_ratings = cudf.read_parquet(negative_ratings_path)
negative_ratings["negative_item"] = negative_ratings["movie_id"]
negative_ratings = (
    negative_ratings
    .set_index(["user_id", "day"])
    .drop(columns=["movie_id", "rating", "interaction", "timestamp"])
)
negative_ratings.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,negative_item
user_id,day,Unnamed: 2_level_1
1,4146,303
1,4146,879
1,4146,1148
1,4146,1187
1,4146,1228


In [11]:
# Group by user id and day to form lists, then sample one negative per session

In [12]:
# One row per (user_id, day) session holding the list of its negative items
# ("collect") and how many there are ("count").
negative_sessions = (
    negative_ratings
    .reset_index()
    .groupby(["user_id", "day"])
    .agg({"negative_item": ["collect", "count"]})
)
negative_sessions.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,negative_item,negative_item
Unnamed: 0_level_1,Unnamed: 1_level_1,collect,count
user_id,day,Unnamed: 2_level_2,Unnamed: 3_level_2
1,4146,"[303, 879, 1148, 1187, 1228, 1923, 1924, 1980,...",31
2,4071,"[1, 62, 259, 264, 376, 476, 520, 548, 581, 643...",71
3,7521,"[172, 438, 476, 765, 1100, 1169, 1238, 1287, 1...",154
3,7688,"[10054, 10168, 10450, 10784, 11114, 12353, 124...",8
3,8045,"[3236, 7847, 12544, 12615, 13813, 13937, 14034...",33


In [13]:
# Pick one random position per session into its collected list of negatives.
# A dedicated seeded generator keeps the sampled negatives (and therefore the
# ranking_training.parquet written at the end) reproducible across
# "Restart & Run All", and re-running this cell gives the same draw.
NEGATIVE_SAMPLING_SEED = 42
negative_rng = np.random.default_rng(NEGATIVE_SAMPLING_SEED)
negative_counts = negative_sessions[("negative_item", "count")].to_pandas().to_numpy()
# Vectorized form of randint(0, count) per session; sessions with no negatives
# get -1. np.maximum keeps the upper bound >= 1 so the draw itself is valid even
# for those rows (their value is replaced by -1 anyway).
sampled_indices = np.where(
    negative_counts > 0,
    negative_rng.integers(0, np.maximum(negative_counts, 1)),
    -1,
).astype(np.int32)

In [14]:
# Resolve each sampled position to an actual item id from the session's list of
# negatives; sessions marked -1 (no negatives) fall back to item 0.
# NOTE(review): 0 is used as a placeholder item id here — confirm it never
# collides with a real movie id.
negative_item_lists = negative_sessions[("negative_item", "collect")].to_pandas()
sampled_items = np.array(
    [
        candidates[position] if position >= 0 else 0
        for position, candidates in zip(sampled_indices, negative_item_lists)
    ],
    dtype=np.int32,
)

In [15]:
# Flatten the ("negative_item", "collect") / ("negative_item", "count") column
# pairs from the groupby into plain names, and attach the sampled negative as
# the session's target_item.
negative_sessions.columns = ["negative_items", "negative_items_count"]
negative_sessions["target_item"] = sampled_items
negative_sessions.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,negative_items,negative_items_count,target_item
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,4146,"[303, 879, 1148, 1187, 1228, 1923, 1924, 1980,...",31,5162
2,4071,"[1, 62, 259, 264, 376, 476, 520, 548, 581, 643...",71,3920
3,7521,"[172, 438, 476, 765, 1100, 1169, 1238, 1287, 1...",154,4262
3,7688,"[10054, 10168, 10450, 10784, 11114, 12353, 124...",8,11114
3,8045,"[3236, 7847, 12544, 12615, 13813, 13937, 14034...",33,16575


In [16]:
# Keep just the sampled negative per session, still keyed by (user_id, day).
negative_targets = negative_sessions[["target_item"]]
negative_targets.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,target_item
user_id,day,Unnamed: 2_level_1
1,4146,5162
2,4071,3920
3,7521,4262
3,7688,11114
3,8045,16575


In [17]:
# Left join keeps every positive session; sessions without a sampled negative
# end up with a null target_item_neg. The suffixes disambiguate the two
# target_item columns (positive vs. negative).
both_targets = positive_interactions.join(
    negative_targets,
    how="left",
    lsuffix="_pos",
    rsuffix="_neg",
)
both_targets

Unnamed: 0_level_0,Unnamed: 1_level_0,target_item_pos,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,target_item_neg
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
6167,4054,1234,"[526, 68122, 32890, 35728, 16493, 50675, 42700...","[6, 9, 16, 2, 7, 9, 18, 7, 9, 6, 16, 9, 17, 3,...","[1139957672, 1139957684, 1139957687, 113995773...","[11, 2868, 1183, 1274, 1528, 2896, 5311, 3022,...",43,"[447, 698, 962, 965, 2212, 2817, 2827, 2832, 2...","[9, 20]",91,2767
6170,875,588,"[8366, 43579, 38914, 5693, 41682, 12393, 4448,...","[3, 4, 5, 6, 10, 5, 6, 10, 14, 2, 3, 18, 6, 2,...","[865250838, 865250882, 865250883, 865250883, 8...","[1, 1048, 490, 5, 372, 773, 1490, 1358, 86, 83...",82,"[361, 547, 640, 685, 1434, 2175, 2177, 2817, 3...","[4, 5, 10, 14, 16, 13]",82,801
6169,5792,6652,"[22562, 37450, 67052, 30565, 59560, 62731, 609...","[6, 10, 6, 12, 14, 17, 4, 5, 10, 14, 5, 6, 3, ...","[1290117144, 1290117158, 1290117184, 129011718...","[1246, 2566, 1250, 2713, 2072, 3809, 2655, 199...",9,"[3545, 6022, 6224, 6540, 6541, 6606, 6755, 124...","[10, 12, 17, 18]",71,787
6172,9019,20580,"[2817, 39421, 22946, 46568, 32052, 50054, 2336...","[7, 9, 2, 7, 9, 18, 3, 4, 5, 6, 10, 16, 2, 3, ...","[1568910192, 1568910198, 1568910227, 156891023...","[315, 2868, 4202, 3480, 1940, 4888, 6752, 7237...",24,"[28225, 30747, 66047]","[6, 8, 9]",3,11275
6174,2176,2055,"[33834, 32342, 3219, 61085, 9126, 57269, 2817,...","[9, 16, 4, 5, 6, 2, 3, 9, 6, 12, 9, 18, 2, 9, ...","[977644832, 977644971, 977644971, 977645013, 9...","[1656, 3651, 3480, 3685, 3315, 3653, 2713, 315...",41,"[431, 703, 709, 721, 2817, 4033, 13340, 13406,...","[6, 16]",50,3457
...,...,...,...,...,...,...,...,...,...,...,...
162533,6248,1259,"[8477, 0]","[6, 9, 16, 4, 5]","[1329514139, 1329514173]","[16699, 16208]",2,"[389, 685, 1136, 1266, 1696, 1964, 2701, 2817,...","[2, 3]",102,17218
162533,5875,15381,[68186],[15],[1297289876],[15464],1,"[17023, 31356, 33288, 33627, 34209, 34429, 435...","[6, 9]",12,
162530,2319,537,"[8770, 7556, 36137, 40642, 56019, 22246, 35740...","[7, 9, 2, 3, 20, 2, 3, 6, 10, 16, 2, 9, 19, 2,...","[990018455, 990018600, 990018600, 990018600, 9...","[841, 1172, 1168, 109, 1208, 1180, 3430, 719, ...",9,"[270, 431, 451, 496, 685, 721, 730, 969, 1964,...","[2, 17, 18]",231,
162521,6134,16083,[29603],"[3, 4, 5, 6, 13]",[1319648967],[11209],1,"[48635, 69170, 71167, 72716]","[6, 9]",4,


In [18]:
# Positive rows: only sessions that received a sampled negative are kept (the
# negative rows below use the same filter), the positive item is restored as
# target_item, and every row is labelled 1.
positive_examples = both_targets[both_targets["target_item_neg"].notna()]
positive_examples["target_item"] = positive_examples["target_item_pos"]
positive_examples = positive_examples.drop(columns=["target_item_pos", "target_item_neg"])
positive_examples["label"] = 1
positive_examples

Unnamed: 0_level_0,Unnamed: 1_level_0,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,target_item,label
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
6167,4054,"[526, 68122, 32890, 35728, 16493, 50675, 42700...","[6, 9, 16, 2, 7, 9, 18, 7, 9, 6, 16, 9, 17, 3,...","[1139957672, 1139957684, 1139957687, 113995773...","[11, 2868, 1183, 1274, 1528, 2896, 5311, 3022,...",43,"[447, 698, 962, 965, 2212, 2817, 2827, 2832, 2...","[9, 20]",91,1234,1
6170,875,"[8366, 43579, 38914, 5693, 41682, 12393, 4448,...","[3, 4, 5, 6, 10, 5, 6, 10, 14, 2, 3, 18, 6, 2,...","[865250838, 865250882, 865250883, 865250883, 8...","[1, 1048, 490, 5, 372, 773, 1490, 1358, 86, 83...",82,"[361, 547, 640, 685, 1434, 2175, 2177, 2817, 3...","[4, 5, 10, 14, 16, 13]",82,588,1
6169,5792,"[22562, 37450, 67052, 30565, 59560, 62731, 609...","[6, 10, 6, 12, 14, 17, 4, 5, 10, 14, 5, 6, 3, ...","[1290117144, 1290117158, 1290117184, 129011718...","[1246, 2566, 1250, 2713, 2072, 3809, 2655, 199...",9,"[3545, 6022, 6224, 6540, 6541, 6606, 6755, 124...","[10, 12, 17, 18]",71,6652,1
6172,9019,"[2817, 39421, 22946, 46568, 32052, 50054, 2336...","[7, 9, 2, 7, 9, 18, 3, 4, 5, 6, 10, 16, 2, 3, ...","[1568910192, 1568910198, 1568910227, 156891023...","[315, 2868, 4202, 3480, 1940, 4888, 6752, 7237...",24,"[28225, 30747, 66047]","[6, 8, 9]",3,20580,1
6174,2176,"[33834, 32342, 3219, 61085, 9126, 57269, 2817,...","[9, 16, 4, 5, 6, 2, 3, 9, 6, 12, 9, 18, 2, 9, ...","[977644832, 977644971, 977644971, 977645013, 9...","[1656, 3651, 3480, 3685, 3315, 3653, 2713, 315...",41,"[431, 703, 709, 721, 2817, 4033, 13340, 13406,...","[6, 16]",50,2055,1
...,...,...,...,...,...,...,...,...,...,...,...
162533,5685,"[31871, 61139, 42965, 45010, 63891, 72645, 201...","[2, 9, 19, 9, 18, 9, 3, 9, 19, 7, 9, 9, 16, 2,...","[1280832291, 1280832341, 1280832346, 128083264...","[2853, 477, 1157, 1561, 2239, 892, 5905, 735, ...",62,"[233, 591, 1624, 1626, 3607, 6022, 6224, 7021,...","[9, 16]",83,5118,1
162533,5879,"[34063, 29481]","[8, 3, 4, 5, 6, 10]","[1297630922, 1297631112]","[15519, 4781]",2,"[4780, 5253, 46621, 56296, 58326, 60478]",[9],6,10091,1
162533,5692,"[59896, 56924, 32286, 67014, 17779, 22036, 354...","[9, 19, 8, 6, 9, 16, 19, 2, 3, 9, 10, 4, 5, 10...","[1281405901, 1281405922, 1281405928, 128140595...","[14804, 9777, 352, 7029, 588, 10001, 3893, 583...",107,"[2817, 4338, 4448, 13252, 18319, 28833, 29481,...","[3, 4, 5, 6, 10, 17]",38,662,1
162533,6248,"[8477, 0]","[6, 9, 16, 4, 5]","[1329514139, 1329514173]","[16699, 16208]",2,"[389, 685, 1136, 1266, 1696, 1964, 2701, 2817,...","[2, 3]",102,1259,1


In [19]:
# Negative rows start from the same sessions: the sampled negative becomes
# target_item, and the positive movie's features are dropped so the negative
# movie's features can be joined in via the target_item index.
negative_interactions = both_targets[both_targets["target_item_neg"].notna()]
negative_interactions["target_item"] = negative_interactions["target_item_neg"]
columns_to_drop = [
    "target_item_pos",
    "target_item_neg",
    "movie_tags_unique",
    "movie_tags_nunique",
    "movie_genres",
]
negative_interactions = (
    negative_interactions
    .drop(columns=columns_to_drop)
    .reset_index()
    .set_index("target_item")
)
negative_interactions

Unnamed: 0_level_0,user_id,day,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count
target_item,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2767,6167,4054,"[526, 68122, 32890, 35728, 16493, 50675, 42700...","[6, 9, 16, 2, 7, 9, 18, 7, 9, 6, 16, 9, 17, 3,...","[1139957672, 1139957684, 1139957687, 113995773...","[11, 2868, 1183, 1274, 1528, 2896, 5311, 3022,...",43
801,6170,875,"[8366, 43579, 38914, 5693, 41682, 12393, 4448,...","[3, 4, 5, 6, 10, 5, 6, 10, 14, 2, 3, 18, 6, 2,...","[865250838, 865250882, 865250883, 865250883, 8...","[1, 1048, 490, 5, 372, 773, 1490, 1358, 86, 83...",82
787,6169,5792,"[22562, 37450, 67052, 30565, 59560, 62731, 609...","[6, 10, 6, 12, 14, 17, 4, 5, 10, 14, 5, 6, 3, ...","[1290117144, 1290117158, 1290117184, 129011718...","[1246, 2566, 1250, 2713, 2072, 3809, 2655, 199...",9
11275,6172,9019,"[2817, 39421, 22946, 46568, 32052, 50054, 2336...","[7, 9, 2, 7, 9, 18, 3, 4, 5, 6, 10, 16, 2, 3, ...","[1568910192, 1568910198, 1568910227, 156891023...","[315, 2868, 4202, 3480, 1940, 4888, 6752, 7237...",24
3457,6174,2176,"[33834, 32342, 3219, 61085, 9126, 57269, 2817,...","[9, 16, 4, 5, 6, 2, 3, 9, 6, 12, 9, 18, 2, 9, ...","[977644832, 977644971, 977644971, 977645013, 9...","[1656, 3651, 3480, 3685, 3315, 3653, 2713, 315...",41
...,...,...,...,...,...,...,...
6380,162533,5685,"[31871, 61139, 42965, 45010, 63891, 72645, 201...","[2, 9, 19, 9, 18, 9, 3, 9, 19, 7, 9, 9, 16, 2,...","[1280832291, 1280832341, 1280832346, 128083264...","[2853, 477, 1157, 1561, 2239, 892, 5905, 735, ...",62
15523,162533,5879,"[34063, 29481]","[8, 3, 4, 5, 6, 10]","[1297630922, 1297631112]","[15519, 4781]",2
10834,162533,5692,"[59896, 56924, 32286, 67014, 17779, 22036, 354...","[9, 19, 8, 6, 9, 16, 19, 2, 3, 9, 10, 4, 5, 10...","[1281405901, 1281405922, 1281405928, 128140595...","[14804, 9777, 352, 7029, 588, 10001, 3893, 583...",107
17218,162533,6248,"[8477, 0]","[6, 9, 16, 4, 5]","[1329514139, 1329514173]","[16699, 16208]",2


In [20]:
# Join the negative movie's features on target_item, re-key by (user_id, day)
# like the positive rows, and label every row 0.
negative_examples = negative_interactions.join(
    movie_features,
    how="left",
    lsuffix="",
    rsuffix="_movie",
)
negative_examples = negative_examples.reset_index().set_index(["user_id", "day"])
negative_examples["label"] = 0
negative_examples.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,target_item,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,label
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
9710,7745,10001,"[18451, 40759, 60381, 59588, 44776, 37047, 535...","[7, 9, 15, 18, 6, 9, 17, 9, 16, 17, 4, 5, 9, 1...","[1458829927, 1458830112, 1458830116, 145883017...","[1935, 1621, 7237, 5860, 2909, 11730, 35337, 1...",23,"[2816, 5253, 6758, 15013, 16718, 18814, 18816,...","[9, 16]",22,0
8656,2659,1342,"[7831, 31587, 55733, 23798, 33064, 12797, 1878...","[6, 6, 2, 3, 6, 3, 6, 17, 20, 6, 7, 6, 16, 6, ...","[1019401721, 1019401721, 1019401886, 101940192...","[340, 19, 2592, 1924, 1667, 2176, 2317, 160, 1...",43,"[3148, 3149, 3150, 3607, 4448, 4786, 5338, 570...","[2, 7]",73,0
8668,3871,2400,"[51461, 44535, 18814, 7474, 21046, 58021, 4471...","[7, 9, 2, 3, 9, 19, 9, 4, 5, 6, 16, 3, 9, 16, ...","[1124136217, 1124136223, 1124136240, 112413625...","[3055, 7019, 6443, 1992, 6862, 8084, 2596, 773...",13,"[30, 3641, 4117, 6635, 6757, 7108, 10554, 1225...","[2, 18]",48,0
8662,5839,6517,"[863, 23181, 43579, 56791, 54448, 28185, 40368...","[6, 2, 3, 5, 6, 10, 6, 9, 14, 16, 6, 7, 6, 9, ...","[1294158662, 1294158691, 1294158697, 129415879...","[1253, 1917, 894, 1098, 889, 5680, 14804, 3005...",16,"[408, 969, 1416, 2673, 2674, 2816, 3219, 3221,...","[9, 18]",34,0
8656,2049,2254,"[26716, 10658, 14664, 42212, 29107, 60081, 622...","[2, 3, 17, 7, 9, 6, 10, 6, 20, 2, 19, 2, 9, 19...","[966653716, 966653779, 966653779, 966654026, 9...","[1340, 1191, 3370, 3771, 3045, 1202, 1931, 194...",165,"[2017, 2066, 3545, 4448, 7090, 8475, 13577, 14...","[2, 3, 9, 18]",28,0


In [21]:
# Stack positive and negative rows into one frame (positives first; the next
# cell shuffles them).
example_frames = [positive_examples, negative_examples]
training_examples = cudf.concat(example_frames)
training_examples.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,target_item,label
user_id,day,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
6167,4054,"[526, 68122, 32890, 35728, 16493, 50675, 42700...","[6, 9, 16, 2, 7, 9, 18, 7, 9, 6, 16, 9, 17, 3,...","[1139957672, 1139957684, 1139957687, 113995773...","[11, 2868, 1183, 1274, 1528, 2896, 5311, 3022,...",43,"[447, 698, 962, 965, 2212, 2817, 2827, 2832, 2...","[9, 20]",91,1234,1
6170,875,"[8366, 43579, 38914, 5693, 41682, 12393, 4448,...","[3, 4, 5, 6, 10, 5, 6, 10, 14, 2, 3, 18, 6, 2,...","[865250838, 865250882, 865250883, 865250883, 8...","[1, 1048, 490, 5, 372, 773, 1490, 1358, 86, 83...",82,"[361, 547, 640, 685, 1434, 2175, 2177, 2817, 3...","[4, 5, 10, 14, 16, 13]",82,588,1
6169,5792,"[22562, 37450, 67052, 30565, 59560, 62731, 609...","[6, 10, 6, 12, 14, 17, 4, 5, 10, 14, 5, 6, 3, ...","[1290117144, 1290117158, 1290117184, 129011718...","[1246, 2566, 1250, 2713, 2072, 3809, 2655, 199...",9,"[3545, 6022, 6224, 6540, 6541, 6606, 6755, 124...","[10, 12, 17, 18]",71,6652,1
6172,9019,"[2817, 39421, 22946, 46568, 32052, 50054, 2336...","[7, 9, 2, 7, 9, 18, 3, 4, 5, 6, 10, 16, 2, 3, ...","[1568910192, 1568910198, 1568910227, 156891023...","[315, 2868, 4202, 3480, 1940, 4888, 6752, 7237...",24,"[28225, 30747, 66047]","[6, 8, 9]",3,20580,1
6174,2176,"[33834, 32342, 3219, 61085, 9126, 57269, 2817,...","[9, 16, 4, 5, 6, 2, 3, 9, 6, 12, 9, 18, 2, 9, ...","[977644832, 977644971, 977644971, 977645013, 9...","[1656, 3651, 3480, 3685, 3315, 3653, 2713, 315...",41,"[431, 703, 709, 721, 2817, 4033, 13340, 13406,...","[6, 16]",50,2055,1


In [22]:
# Shuffle rows so positives and negatives are interleaved in the written file
# (the concat above places all positives first). A dedicated seeded generator
# makes the row order reproducible across "Restart & Run All".
SHUFFLE_SEED = 42
shuffle_rng = np.random.default_rng(SHUFFLE_SEED)
flat_examples = training_examples.reset_index()
shuffled_examples = flat_examples.iloc[shuffle_rng.permutation(len(flat_examples))]
shuffled_examples

Unnamed: 0,user_id,day,user_search_terms,user_genres,user_timestamps,user_movie_ids,user_movie_id_count,movie_tags_unique,movie_genres,movie_tags_nunique,target_item,label
125596,49044,8951,[55603],[9],[1562976653],[18718],1,"[35180, 50444, 53677, 57311, 66230, 71291]","[2, 9, 19]",6,23015,1
755480,115668,1737,"[965, 48689, 8100, 48496, 69957, 31423, 52524,...","[9, 15, 16, 18, 2, 3, 6, 10, 16, 2, 3, 7, 9, 7...","[939718393, 939718701, 939719066, 939719066, 9...","[883, 1168, 1169, 841, 50, 1053, 1256, 586, 23...",24,"[408, 957, 965, 967, 2816, 2827, 3221, 3695, 5...","[6, 9, 16]",86,889,0
681592,90558,7761,"[68303, 1395, 425, 55205, 30105, 30791, 30064,...","[7, 9, 7, 9, 2, 3, 17, 6, 14, 16, 2, 3, 17, 9,...","[1460208914, 1460208926, 1460208931, 146020893...","[315, 841, 258, 12217, 1167, 7237, 10001, 1067...",63,"[135, 270, 613, 685, 697, 888, 965, 968, 1469,...","[2, 9, 19]",153,1178,0
580860,47901,9046,"[66828, 2317, 44072, 22744, 26425, 25379]","[6, 19, 7, 12, 9, 7, 9, 17, 18, 2, 12, 17, 18,...","[1571252405, 1571252537, 1571252551, 157125256...","[735, 1189, 2223, 1176, 2198, 2461]",6,"[556, 2915, 3121, 4004, 4171, 4266, 4268, 5184...","[15, 17, 18]",126,32,0
502546,20918,6423,"[43458, 48137, 71514, 57913, 31257, 30848, 281...","[20, 7, 15, 16, 18, 9, 17, 18, 2, 7, 15, 17, 1...","[1344628159, 1344628161, 1344628163, 134462816...","[13281, 11162, 12425, 5338, 10112, 6222, 13262...",8,"[18446, 30023, 30064, 32292, 55598, 56728, 670...",[9],8,12320,0
...,...,...,...,...,...,...,...,...,...,...,...,...
832054,139705,584,"[37873, 685, 66884, 547, 67574, 43316, 40665, ...","[6, 7, 9, 18, 2, 7, 18, 7, 12, 18, 4, 5, 10, 1...","[840110079, 840110079, 840110179, 840110179, 8...","[293, 585, 586, 588, 160, 315, 289, 453, 476, ...",28,"[3575, 4004, 5273, 5303, 5693, 9206, 9760, 168...",[6],18,578,0
8799,1748,7460,[8311],"[3, 6, 9]",[1434230951],[26758],1,"[39888, 69870]",[6],2,13622,1
726302,97160,7246,[44739],"[2, 10, 17]",[1415732968],[20456],1,[58164],[8],1,18997,0
181834,63151,4788,"[47665, 34316, 32665, 55489, 69156, 25509]","[6, 9, 8, 6, 9, 6, 9]","[1203301899, 1203301933, 1203301941, 120330196...","[193, 161, 2020, 1069, 1410, 1157]",6,"[6172, 6224, 9432, 13291, 13334, 13736, 15720,...","[9, 15]",52,2595,1


In [23]:
# Persist the shuffled, labelled examples for ranking-model training.
ranking_training_path = os.path.join(INPUT_DATA_DIR, "ranking_training.parquet")
shuffled_examples.to_parquet(ranking_training_path)

In [24]:
# Revert the managed_memory=True setting applied at the top of the notebook.
rmm.reinitialize(managed_memory=False)