In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
sys.path.append(os.path.join(os.path.abspath(''), '../'))

from models.graphsage_model import GraphSAGE
from pprint import pprint
from dataset.data_loader import DataLoader, playtime_forever_users_games_edge_scoring_function, GaussianNormalizer

from utils.utils import get_game_name_and_scores

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%reload_ext autoreload

In [3]:
# Build the training graph: user->game edges scored by playtime_forever, edge
# and interaction scores passed through GaussianNormalizer, and a set of game
# feature embeddings.
# (DataLoader also accepts num_users_to_load_per_snowball=..., presumably to cap
#  the number of users loaded for a quicker run — confirm in dataset/data_loader.py.)
GAME_EMBEDDINGS = ['numReviews', 'avgReviewScore', 'price', 'numFollowers', 'genres', 'tags', 'name']

data_loader = DataLoader(
    users_games_edge_scoring_function=playtime_forever_users_games_edge_scoring_function,
    user_game_edge_embeddings=['playtime_forever'],
    # NOTE(review): GaussianNormalizer args are presumably (mean, std) — confirm.
    users_games_edge_score_normalizers=[GaussianNormalizer(1.0, 1.0)],
    interactions_score_normalizers=[GaussianNormalizer(0.0, 1.0)],
    cache_local_dataset=True,
    game_embeddings=GAME_EMBEDDINGS,
)

# Seeded 90/10 random split of the edges into train/test.
data_loader.load_random_edge_train_test_split(train_percentage=0.9, test_percentage=0.1, seed=0)

# Persist the loader configuration under a name that the fine-tuning cell
# reloads with DataLoader.load_from_file.
data_loader.save_data_loader_parameters('test_graphsage_data_loader', overwrite=True)

In [7]:
save_file_name = 'test_graphsage_model'

# Training hyperparameters, collected in one place so they are easy to tune.
GRAPHSAGE_HPARAMS = dict(
    hidden_channels=50,
    aggr='mean',
    num_epochs=50,
    batch_percent=0.1,
    learning_rate=1e-3,
    weight_decay=1e-10,  # NOTE(review): effectively no regularisation — intentional?
    seed=12412,
)

model = GraphSAGE(save_file_name=save_file_name, nn_save_name='best', **GRAPHSAGE_HPARAMS)
model.set_data_loader(data_loader)
model.train(debug=False)  # ~23 min for 50 epochs in the recorded run

# NOTE(review): in the recorded outputs, the model reloaded from this save
# scores very differently from this in-memory model. Confirm whether save()
# writes the final weights or the nn_save_name='best' checkpoint.
model.save(save_file_name, overwrite=True)

Total Learnable Parameters: 72801


Training: 100%|██████████| 50/50 [23:04<00:00, 27.69s/it]


In [8]:
# Score games for one example user a single time, then show the first and last
# 10 entries of that same ranking (the recorded output is ordered by descending
# score). This avoids running inference twice and guarantees both views come
# from one consistent prediction.
example_user_id = 76561198835352289
trained_ranking = model.score_and_predict_n_games_for_user(example_user_id)
display(get_game_name_and_scores(data_loader, trained_ranking[:10]))
display(get_game_name_and_scores(data_loader, trained_ranking[-10:]))

Unnamed: 0,id,name,score
0,26800,Braid,1.801402
1,1137350,Filament,1.772937
2,799890,REVENGER: Age of Morons,1.722109
3,241560,The Crew™,1.685062
4,337980,Vagrant Hearts,1.684349
5,405720,Perfect Universe - Play with Gravity,1.663582
6,266430,Anarchy Arcade,1.660855
7,688470,汉匈决战/Han Xiongnu Wars,1.660311
8,113020,Monaco: What's Yours Is Mine,1.650971
9,693830,Containment Corps,1.647817


Unnamed: 0,id,name,score
0,473810,Killbot,-19.119495
1,265770,Cannons Lasers Rockets,-19.135603
2,415860,Tactical Craft Online,-22.506048
3,345330,Eden Rising,-27.239796
4,1049800,BLOCKADE,-27.310291
5,548480,NightZ,-28.076128
6,1127460,Mod and Play,-33.590176
7,453270,Madness Cubed,-36.975391
8,466940,The Sandbox Evolution - Craft a 2D Pixel Unive...,-55.531349
9,1132530,Wolf Ridge,-56.241768


In [18]:
# Reload the saved model from disk and score the same example user, as a
# save/load round-trip check.
# NOTE(review): in the recorded outputs these scores differ markedly from the
# in-memory model's (different top games, much narrower score range). Verify
# which weights save()/load() persist.
model2 = GraphSAGE()
model2.load('test_graphsage_model', load_published_model=False)
model2.set_data_loader(data_loader)

# Rank once and slice both ends, rather than running inference twice.
example_user_id = 76561198835352289
reloaded_ranking = model2.score_and_predict_n_games_for_user(example_user_id)
display(get_game_name_and_scores(data_loader, reloaded_ranking[:10]))
display(get_game_name_and_scores(data_loader, reloaded_ranking[-10:]))

Unnamed: 0,id,name,score
0,814380,Sekiro™: Shadows Die Twice - GOTY Edition,1.212403
1,374320,DARK SOULS™ III,1.156476
2,489830,The Elder Scrolls V: Skyrim Special Edition,1.155438
3,1113560,NieR Replicant™ ver.1.22474487139...,1.151785
4,524220,NieR:Automata™,1.150343
5,601150,Devil May Cry 5,1.143475
6,678960,CODE VEIN,1.142847
7,1091500,Cyberpunk 2077,1.139257
8,292030,The Witcher® 3: Wild Hunt,1.136536
9,367520,Hollow Knight,1.136029


Unnamed: 0,id,name,score
0,521200,Epic Battle Fantasy 3,1.01061
1,1568590,Goose Goose Duck,1.010098
2,1677740,Stumble Guys,1.009783
3,1184140,KartRider: Drift,1.009517
4,429470,Space Pilgrim Episode I: Alpha Centauri,1.008157
5,283960,Pajama Sam: No Need to Hide When It's Dark Out...,1.008152
6,1006120,Tetsumo Party,1.007614
7,657630,Fidget Spinner,1.007023
8,825580,Grotoro,1.00629
9,746920,Rapid Tap,1.006117


In [5]:
# Test fine-tuning on live data: reload the saved loader parameters (live
# variant) and the saved model, then fine-tune on a single user.
# NOTE(review): this rebinds the notebook-level `model` and `data_loader`
# created in earlier cells.
model = GraphSAGE()

data_loader = DataLoader.load_from_file('test_graphsage_data_loader', use_published_models_path=False, load_live_data_loader=True)

model.load('test_graphsage_model', load_published_model=False)
# NOTE(review): new_seed(None) presumably draws a fresh, unrecorded seed, which
# makes this cell's results non-reproducible — consider passing a fixed seed.
model.new_seed(None)
model.set_data_loader(data_loader)

# Other user ids tried previously: 76561198166465514, 76561198835352289.
user_to_fine_tune = 76561198103368250
model.fine_tune(user_to_fine_tune, debug=False)

print('Fine Tuned User Output')
# Rank once and show both ends of the same ranking, rather than running
# inference twice.
fine_tuned_ranking = model.score_and_predict_n_games_for_user(user_to_fine_tune)
display(get_game_name_and_scores(data_loader, fine_tuned_ranking[:10]))
display(get_game_name_and_scores(data_loader, fine_tuned_ranking[-10:]))

HeteroData(
  user={ x=[79314, 1] },
  game={ x=[34088, 473] },
  (user, plays, game)={
    edge_index=[2, 21764429],
    edge_label=[21764429],
  },
  (game, rev_plays, user)={ edge_index=[2, 21764429] }
)
HeteroData(
  user={ x=[79315, 1] },
  game={ x=[34088, 473] },
  (user, plays, game)={
    edge_index=[2, 21764921],
    edge_label=[21764921],
  },
  (game, rev_plays, user)={ edge_index=[2, 21764921] }
)
Fine Tuned User Output


Unnamed: 0,id,name,score
0,1113560,NieR Replicant™ ver.1.22474487139...,1.104521
1,524220,NieR:Automata™,1.103208
2,601150,Devil May Cry 5,1.09634
3,678960,CODE VEIN,1.095711
4,582160,Assassin's Creed® Origins,1.083393
5,1096720,CATGIRL LOVER,1.078741
6,883710,Resident Evil 2,1.077821
7,460790,Bayonetta,1.077745
8,262060,Darkest Dungeon®,1.077404
9,752590,A Plague Tale: Innocence,1.076158


Unnamed: 0,id,name,score
0,516790,Fat Mask,0.955767
1,521200,Epic Battle Fantasy 3,0.955488
2,1677740,Stumble Guys,0.954661
3,1184140,KartRider: Drift,0.954395
4,429470,Space Pilgrim Episode I: Alpha Centauri,0.953035
5,283960,Pajama Sam: No Need to Hide When It's Dark Out...,0.95303
6,1006120,Tetsumo Party,0.952493
7,657630,Fidget Spinner,0.951901
8,825580,Grotoro,0.951168
9,746920,Rapid Tap,0.950995
