In [None]:
!pip install torch-geometric torch-sparse torch-scatter

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.0/210.0 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: torch-sparse, torch-scatter


In [None]:
import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

In [None]:
# Extract restaurant set.
# `listres` gets one entry per (keyword, restaurant) pair found in
# train_data['np2rests'], so duplicates are included; `restaurant_set`
# de-duplicates it.
# Assumes np2rests maps keyword -> mapping keyed by restaurant id (TODO confirm).
listres = []
for kw, kw_rests in train_data['np2rests'].items():
    listres.extend(kw_rests.keys())
restaurant_set = set(listres)


In [None]:
# Convert sets to lists for indexing.
# Sorting pins the node order. Iterating a set of strings follows hash order, which
# Python randomizes per process, so `list(set)` would assign different node indices
# (and give the final score matrix different column meanings) after every kernel restart.
# Assumes keyword_set arrives unordered (a set) and that all ids are mutually
# comparable (e.g. all str) — TODO confirm.
keyword_set = sorted(keyword_set)
restaurant_set = sorted(restaurant_set)
# Number of (keyword, restaurant) pairs, duplicates included — NOT the unique count.
restaurants = len(listres)
num_keywords = len(keyword_set)
num_restaurants = len(restaurant_set)


In [None]:
# Extract users.
# `extract_users` is defined outside this section. From its use below, it presumably
# returns (users, per-user keyword collections) aligned by position — TODO confirm.
train_users, train_users2kw = extract_users(train_data['np2users'])
num_users = len(train_users)
# The graph cell uses each position in train_users2kw as a user node index into a
# feature matrix with num_users rows, so the two must line up.
assert len(train_users2kw) == num_users, (
    f"extract_users returned {num_users} users but {len(train_users2kw)} keyword lists"
)

In [None]:
# Create heterogeneous graph
data = HeteroData()

# Node features (simple one-hot encodings for simplicity)
data['user'].x = torch.eye(num_users)
data['keyword'].x = torch.eye(num_keywords)
data['restaurant'].x = torch.eye(num_restaurants)

# Edges: user -> keyword
# Position lookup built once instead of an O(len(keyword_set)) list scan per keyword.
# setdefault keeps the first occurrence, matching list.index semantics.
keyword_pos = {}
for pos, keyword in enumerate(keyword_set):
    keyword_pos.setdefault(keyword, pos)

edge_index_user_keyword = [[], []]
for user_idx, kws in enumerate(train_users2kw):
    for kw in kws:
        # Keywords not present in keyword_set are skipped.
        if kw in keyword_pos:
            kw_idx = keyword_pos[kw]
            edge_index_user_keyword[0].append(user_idx)
            edge_index_user_keyword[1].append(kw_idx)
data['user', 'interacts', 'keyword'].edge_index = torch.tensor(edge_index_user_keyword, dtype=torch.long)

In [None]:
# Edges: keyword -> restaurant
# Position lookups built once instead of an O(n) list.index scan per edge.
# setdefault keeps the first occurrence, matching list.index semantics.
keyword_pos = {}
for pos, keyword in enumerate(keyword_set):
    keyword_pos.setdefault(keyword, pos)
restaurant_pos = {}
for pos, restaurant in enumerate(restaurant_set):
    restaurant_pos.setdefault(restaurant, pos)

edge_index_keyword_restaurant = [[], []]
for kw, kw_rests in train_data['np2rests'].items():
    # Strict lookup: every np2rests keyword/restaurant must already be indexed (KeyError otherwise).
    kw_idx = keyword_pos[kw]
    for res in kw_rests.keys():
        res_idx = restaurant_pos[res]
        edge_index_keyword_restaurant[0].append(kw_idx)
        edge_index_keyword_restaurant[1].append(res_idx)
data['keyword', 'describes', 'restaurant'].edge_index = torch.tensor(edge_index_keyword_restaurant, dtype=torch.long)

In [None]:
# Define GNN model
class GNNRecommender(torch.nn.Module):
    """Two-layer GCN over the keyword/restaurant graph that scores restaurants per keyword.

    ``GCNConv`` is a homogeneous layer: every index in ``edge_index`` must address a
    row of the single feature matrix it receives. The keyword->restaurant edges carry
    restaurant indices in their second row. Pairing them with the keyword features
    alone would therefore read restaurant ``i`` as keyword ``i``, and would index past
    the end whenever there are more restaurants than keywords.

    The forward pass instead merges both node types into one node set: keywords
    first, then restaurants offset by the keyword count. It also adds the reverse
    edges so keyword nodes aggregate messages from their restaurants.

    Args:
        hidden_dim: width of both GCN layers.
        n_keywords: number of keyword nodes; defaults to the notebook's ``num_keywords``.
        n_restaurants: number of restaurant nodes (= scores per keyword); defaults to
            the notebook's ``num_restaurants``.

    ``forward(data)`` returns unnormalized scores of shape ``(n_keywords, n_restaurants)``.
    """

    def __init__(self, hidden_dim=64, n_keywords=None, n_restaurants=None):
        super().__init__()
        if n_keywords is None:
            n_keywords = num_keywords
        if n_restaurants is None:
            n_restaurants = num_restaurants
        # The merged input is block_diag(keyword.x, restaurant.x). With the one-hot
        # node features built for this graph, that is n_keywords + n_restaurants columns.
        # Other feature widths would need this in_channels adjusted.
        self.conv1 = GCNConv(n_keywords + n_restaurants, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = torch.nn.Linear(hidden_dim, n_restaurants)

    def forward(self, data):
        kw_x = data['keyword'].x
        res_x = data['restaurant'].x
        n_kw = kw_x.size(0)

        # One feature matrix for all nodes; block_diag keeps the two feature spaces disjoint.
        x = torch.block_diag(kw_x, res_x)

        # Shift restaurant ids past the keyword rows, then add the reverse edges so
        # information flows restaurant -> keyword as well as keyword -> restaurant.
        src, dst = data['keyword', 'describes', 'restaurant'].edge_index
        dst = dst + n_kw
        edge_index = torch.stack([torch.cat([src, dst]), torch.cat([dst, src])])

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Score only the keyword rows -> (n_keywords, n_restaurants) logits.
        return self.fc(x[:n_kw])

# Initialize and train model
torch.manual_seed(0)  # fixed weight init so the resulting scores are reproducible
model = GNNRecommender()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Supervision from the observed keyword -> restaurant edges. For each keyword, the
# target spreads its mass evenly over the restaurants it is linked to. This matches
# the row-wise softmax applied to the model output at scoring time.
# (An MSE-to-zeros objective would only push every logit toward 0, i.e. uniform scores.)
# NOTE(review): if the values stored in train_data['np2rests'][kw] are weights/counts,
# they could replace the 1.0 entries below — confirm their meaning first.
edge_kw, edge_res = data['keyword', 'describes', 'restaurant'].edge_index
kw_rest_target = torch.zeros(num_keywords, num_restaurants)
kw_rest_target[edge_kw, edge_res] = 1.0
# Keywords with no restaurants keep an all-zero row and therefore contribute zero loss.
kw_rest_target = kw_rest_target / kw_rest_target.sum(dim=1, keepdim=True).clamp_min(1.0)

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    out = model(data)
    # Soft-label cross-entropy: -sum_j target[i, j] * log_softmax(out)[i, j], averaged over keywords.
    loss = -(kw_rest_target * F.log_softmax(out, dim=1)).sum(dim=1).mean()
    loss.backward()
    optimizer.step()

In [None]:
# Generate restaurant scores
# Inference pass: switch the model to eval mode and record no autograd graph.
model.eval()
with torch.no_grad():
    # Raw per-keyword scores over restaurants (output of the final Linear layer).
    # The variable name is kept as-is in case later cells refer to it.
    keyword_embeddings = model(data)
    # Normalize each row into a distribution over restaurants.
    a = keyword_embeddings.softmax(dim=1).numpy()  # Shape: (num_keywords, num_restaurants)