In [1]:
import random

import numpy as np
import torch
from torch import nn
from torch_geometric.data import NeighborSampler

from modeling import GNN, HybridNetwork, train
from dataset import HybridDataset

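## Dataset

Load the restaurant graph, the user friendship graph and the 2017–2018 visited-users table, then split the restaurant nodes 80/10/10 into train/validation/test sets.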

In [2]:
# Seed every RNG before anything stochastic runs. This makes the 80/10/10 split
# (if it is random), the model's weight initialisation and the NeighborSampler
# shuffling reproducible under Restart Kernel -> Run All.
# NOTE(review): assumes HybridDataset draws its split from torch / numpy /
# random. Confirm in dataset.py.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

h_dataset = HybridDataset(
    "../graphs/restaurants_user_influence2.gpickle",  # restaurant graph
    "../graphs/2017-2018_user_network.gpickle",       # user friendship graph
    "../datasets/2017-2018_visited_users.csv",        # restaurant visitors
    split=[0.8, 0.1, 0.1],                            # train / val / test fractions
)

Number of restaurants: 29963
Number of neighbors: 491371
converting restaurant graph to pyg graph... done!
Number of users: 559439
Number of friends: 1448583
converting restaurant graph to pyg graph... done!
2033


In [3]:
# Both loaders sample from the restaurant graph with two hops and no neighbour
# subsampling (PyG: a size of -1 keeps every neighbour at that hop). The two
# hops line up with num_layers=2 of the HybridNetwork built below.
res_edge_index = h_dataset.res_pyg_graph.edge_index

# Training loader: a single shuffled batch that holds every training node.
train_loader = NeighborSampler(
    res_edge_index,
    node_idx=h_dataset.train_index,
    sizes=[-1] * 2,
    batch_size=h_dataset.train_index.shape[0],
    shuffle=True,
)

# Inference loader: every restaurant node (node_idx=None) in one fixed-order batch.
all_loader = NeighborSampler(
    res_edge_index,
    node_idx=None,
    sizes=[-1] * 2,
    batch_size=h_dataset.res_pyg_graph.num_nodes,
    shuffle=False,
)

In [4]:
# Keep GPU 2 (the original choice) when it exists. Fall back to the first GPU on
# machines with fewer devices, and to the CPU when CUDA is unavailable.
# A hard-coded 'cuda:2' makes the .to(device) calls below fail with
# "invalid device ordinal" on smaller machines.
if torch.cuda.is_available():
    gpu_index = 2 if torch.cuda.device_count() > 2 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# NOTE(review): the return value is unused, so HybridDataset.to presumably moves
# its tensors in place. Confirm in dataset.py.
h_dataset.to(device)

model = HybridNetwork(h_dataset.num_res_features,
                      h_dataset.num_user_features,
                      64,  # presumably the hidden width. Confirm in modeling.py
                      h_dataset.num_class,
                      num_layers=2,  # keep equal to the number of hops in the loaders' `sizes`
                      model="GraphSage",
                      aggr="max").to(device)

# Restaurant-graph-only baseline that was tried earlier:
#   GNN(h_dataset.num_res_features, 64, is_final=True, output_size=3,
#       num_layers=2, model="GraphSage", aggr="max")

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Loss criterion handed to train(). It is named `loss` because the training cell
# refers to it by that name.
loss = nn.CrossEntropyLoss()

In [5]:
# Development helper, left disabled. After editing modeling.py, re-import it
# without restarting the kernel. Re-run the model cell afterwards: existing
# objects such as `model` still belong to the old class definitions.
# Prefer `%load_ext autoreload` + `%autoreload 2` in the first cell over this
# manual reload.
# from importlib import reload

# import modeling
# reload(modeling)
# from modeling import GNN, HybridNetwork, train

In [6]:
# Run training. The captured log below shows train/val/test scores every 5 epochs.
# NOTE(review): the second positional argument is presumably the epoch count,
# and `k=20` is interpreted inside modeling.train. Confirm both there
# (k is not the logging interval).
num_epochs = 250
train(model, num_epochs, h_dataset, train_loader, optimizer, loss, device,
      k=20, all_loader=all_loader)

epoch: 1, Train: 0.3284, Val: 0.3374, Test: 0.3370
epoch: 6, Train: 0.3458, Val: 0.3555, Test: 0.3540
epoch: 11, Train: 0.3521, Val: 0.3334, Test: 0.3520
epoch: 16, Train: 0.4023, Val: 0.4122, Test: 0.3887
epoch: 21, Train: 0.4528, Val: 0.4556, Test: 0.4578
epoch: 26, Train: 0.4894, Val: 0.5090, Test: 0.4855
epoch: 31, Train: 0.5232, Val: 0.5421, Test: 0.5252
epoch: 36, Train: 0.5571, Val: 0.5738, Test: 0.5612
epoch: 41, Train: 0.5910, Val: 0.5885, Test: 0.5946
epoch: 46, Train: 0.6257, Val: 0.6111, Test: 0.6336
epoch: 51, Train: 0.6338, Val: 0.6225, Test: 0.6393
epoch: 56, Train: 0.6385, Val: 0.6242, Test: 0.6446
epoch: 61, Train: 0.6333, Val: 0.6328, Test: 0.6490
epoch: 66, Train: 0.6540, Val: 0.6412, Test: 0.6627
epoch: 71, Train: 0.6506, Val: 0.6389, Test: 0.6603
epoch: 76, Train: 0.6631, Val: 0.6442, Test: 0.6690
epoch: 81, Train: 0.6643, Val: 0.6328, Test: 0.6583
epoch: 86, Train: 0.6786, Val: 0.6529, Test: 0.6857
epoch: 91, Train: 0.6732, Val: 0.6532, Test: 0.6763
epoch: 96, Tra