##### Heterogeneous graph construction

In [1]:
import json
import os
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import GATConv, RGCNConv
import pickle
import pandas as pd
import copy
from tqdm import tqdm

In [2]:
data_dir = "data/beauty/handled"
id_map = json.load(open(os.path.join(data_dir, "id_map.json")))
item2attr = json.load(open(os.path.join(data_dir, "item2attributes.json")))

In [3]:
import json
import torch
from torch_geometric.data import HeteroData

def build_prompt_graph(id_map_path, item2attr_path):
    """
    Constructs a heterogeneous item-attribute graph from item2id and item2attributes JSON files.
    """

    # Load mappings
    id_map = json.load(open(id_map_path))["item2id"]  # maps ASIN to integer ID
    item2attr = json.load(open(item2attr_path))       # maps ASIN to attribute list

    # Keep only items present in both files
    valid_raw_ids = set(item2attr.keys()) & set(id_map.keys())
    print(f"# of matched items: {len(valid_raw_ids)}")

    # Map valid item raw IDs to their dense integer ID
    mapped_attr = {int(id_map[raw_id]): item2attr[raw_id] for raw_id in valid_raw_ids}
    max_item_id = max(mapped_attr.keys())

    # Create graph
    graph = HeteroData()
    graph['item'].num_nodes = max_item_id + 1

    attr2id = {}
    attr_cnt = 0

    # Edge containers
    edge_index_dict = {
        ('item', 'has_brand', 'attribute'): [[], []],
        ('attribute', 'rev_has_brand', 'item'): [[], []],
        ('item', 'has_category', 'attribute'): [[], []],
        ('attribute', 'rev_has_category', 'item'): [[], []],
        ('item', 'has_price', 'attribute'): [[], []],
        ('attribute', 'rev_has_price', 'item'): [[], []],
    }

    # Build edge indices
    for item_id, attrs in mapped_attr.items():
        for attr in attrs:
            if ':' not in attr:
                continue  # skip attributes without prefix (e.g. generic)
            prefix, value = attr.split(':', 1)

            if attr not in attr2id:
                attr2id[attr] = attr_cnt
                attr_cnt += 1

            aid = attr2id[attr]

            if prefix == 'brand':
                edge_index_dict[('item', 'has_brand', 'attribute')][0].append(item_id)
                edge_index_dict[('item', 'has_brand', 'attribute')][1].append(aid)
                edge_index_dict[('attribute', 'rev_has_brand', 'item')][0].append(aid)
                edge_index_dict[('attribute', 'rev_has_brand', 'item')][1].append(item_id)

            elif prefix in ['cat', 'category']:
                edge_index_dict[('item', 'has_category', 'attribute')][0].append(item_id)
                edge_index_dict[('item', 'has_category', 'attribute')][1].append(aid)
                edge_index_dict[('attribute', 'rev_has_category', 'item')][0].append(aid)
                edge_index_dict[('attribute', 'rev_has_category', 'item')][1].append(item_id)

            elif prefix == 'price':
                edge_index_dict[('item', 'has_price', 'attribute')][0].append(item_id)
                edge_index_dict[('item', 'has_price', 'attribute')][1].append(aid)
                edge_index_dict[('attribute', 'rev_has_price', 'item')][0].append(aid)
                edge_index_dict[('attribute', 'rev_has_price', 'item')][1].append(item_id)

    # Assign number of attribute nodes
    graph['attribute'].num_nodes = attr_cnt

    # Convert edge lists to PyTorch tensors
    for rel, (src, dst) in edge_index_dict.items():
        graph[rel].edge_index = torch.tensor([src, dst], dtype=torch.long)

    # Print basic stats
    print(f"Item nodes: {graph['item'].num_nodes}")
    print(f"Attribute nodes: {graph['attribute'].num_nodes}")
    for rel in graph.edge_index_dict:
        print(f"{rel}: {graph[rel].edge_index.shape[1]} edges")

    return graph

In [4]:
graph = build_prompt_graph(data_dir+"/"+"id_map.json", data_dir+"/"+"item2attributes_flat.json")
print(f"Item nodes: {graph['item'].num_nodes}")
print(f"Attribute nodes: {graph['attribute'].num_nodes}")

for rel in graph.edge_index_dict:
    print(f"{rel}: {graph[rel].edge_index.shape[1]} edges")

# of matched items: 57289
Item nodes: 57290
Attribute nodes: 6474
('item', 'has_brand', 'attribute'): 41458 edges
('attribute', 'rev_has_brand', 'item'): 41458 edges
('item', 'has_category', 'attribute'): 57289 edges
('attribute', 'rev_has_category', 'item'): 57289 edges
('item', 'has_price', 'attribute'): 50299 edges
('attribute', 'rev_has_price', 'item'): 50299 edges
Item nodes: 57290
Attribute nodes: 6474
('item', 'has_brand', 'attribute'): 41458 edges
('attribute', 'rev_has_brand', 'item'): 41458 edges
('item', 'has_category', 'attribute'): 57289 edges
('attribute', 'rev_has_category', 'item'): 57289 edges
('item', 'has_price', 'attribute'): 50299 edges
('attribute', 'rev_has_price', 'item'): 50299 edges


In [5]:
hidden_dim = 64
graph['item'].x = torch.randn(graph['item'].num_nodes, hidden_dim)
graph['attribute'].x = torch.randn(graph['attribute'].num_nodes, hidden_dim)

##### prompt GAT

In [6]:
from torch_geometric.nn import GATConv, HeteroConv
import torch.nn as nn

class PromptGAT(torch.nn.Module):
    def __init__(self, hidden_dim, dropout=0.2):
        super().__init__()
        self.dropout = dropout

        self.convs1 = HeteroConv({
            ('item', 'has_brand', 'attribute'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('attribute', 'rev_has_brand', 'item'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('item', 'has_category', 'attribute'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('attribute', 'rev_has_category', 'item'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('item', 'has_price', 'attribute'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('attribute', 'rev_has_price', 'item'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
        }, aggr='sum')

        self.norm1 = nn.LayerNorm(hidden_dim)

        self.convs2 = HeteroConv({
            ('item', 'has_brand', 'attribute'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('attribute', 'rev_has_brand', 'item'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('item', 'has_category', 'attribute'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('attribute', 'rev_has_category', 'item'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('item', 'has_price', 'attribute'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            ('attribute', 'rev_has_price', 'item'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
        }, aggr='sum')

        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.convs1(x_dict, edge_index_dict)
        x_dict = {k: self.norm1(F.relu(F.dropout(v, p=self.dropout, training=self.training))) for k, v in x_dict.items()}

        x_dict = self.convs2(x_dict, edge_index_dict)
        x_dict = {k: self.norm2(F.relu(F.dropout(v, p=self.dropout, training=self.training))) for k, v in x_dict.items()}

        return x_dict

##### get GNN embedding and use HeteroConv to deal with each type of relationship separately

#### Apply Graph Contrastive Learning


In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.transforms import ToUndirected
from tqdm import trange
import pickle

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

graph = graph.to(device)
model = PromptGAT(hidden_dim=64, dropout=0.2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# contrastive loss
def contrastive_loss(item_emb, attr_emb, edge_index, num_attr, num_neg=10):
    loss = 0.0
    total = edge_index.shape[1]

    for i in range(total):
        item_idx = edge_index[0, i]
        attr_pos_idx = edge_index[1, i]

        z_i = item_emb[item_idx]          # (64,)
        z_pos = attr_emb[attr_pos_idx]    # (64,)

        neg_indices = torch.randint(0, num_attr, (num_neg,), device=item_emb.device)
        z_neg = attr_emb[neg_indices]     # (num_neg, 64)

        pos_score = torch.exp(F.cosine_similarity(z_i, z_pos, dim=0) / 0.1)  # temperature = 0.1
        neg_score = torch.exp(torch.cosine_similarity(z_i.unsqueeze(0), z_neg, dim=1) / 0.1).sum()

        loss += -torch.log(pos_score / (pos_score + neg_score + 1e-8))

    return loss / total

# train
print("Start contrastive training...")
for epoch in tqdm(trange(100)):
    model.train()
    optimizer.zero_grad()

    out_dict = model(graph.x_dict, graph.edge_index_dict)
    item_emb = out_dict['item']
    attr_emb = out_dict['attribute']

    loss = 0
    for rel in tqdm(['has_brand', 'has_category', 'has_price']):
        
        edge_index = graph[('item', rel, 'attribute')].edge_index
        if edge_index.size(1) == 0:
            continue
        loss += contrastive_loss(item_emb, attr_emb, edge_index, graph['attribute'].num_nodes)

    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss = {loss.item():.4f}")

# save
model.eval()
with torch.no_grad():
    out_dict = model(graph.x_dict, graph.edge_index_dict)
    item_emb = out_dict['item'].cpu().numpy()

with open("data/beauty/handled/gnn_item_emb.pkl", "wb") as f:
    pickle.dump(item_emb, f)

print("GNN embedding saved!")

Start contrastive training...


  0%|          | 0/100 [00:00<?, ?it/s]
[A
[A
[A
100%|██████████| 3/3 [00:10<00:00,  3.45s/it]


##### Future: AutoEncoder / SimCLR?

In [None]:
##### work to be discussed #####

##### some check for dims

In [9]:
import pickle
with open("data/beauty/handled/itm_emb_np.pkl","rb") as f:
    data_llm = pickle.load(f)
data_llm.shape

(57289, 1536)

In [13]:
with open("data/beauty/handled/pca64_itm_emb_np.pkl","rb") as f:
    data_pca = pickle.load(f)
data_pca.shape

(57289, 64)

In [14]:
with open("graph_data/gnn_item_emb.pkl","rb") as f:
    data_gnn = pickle.load(f)
data_gnn.shape

(57289, 64)

##### concatenate collaborative view and gnn view

In [None]:
import numpy as np

# Load LLM PCA64 embedding
with open("data/beauty/handled/pca64_itm_emb_np.pkl", "rb") as f:
    col_emb = pickle.load(f)  # shape: [num_items, 64]

# Load GNN output
with open("graph_data/gnn_itm_emb.pkl", "rb") as f:
    gnn_emb = pickle.load(f)  # shape: [num_items, 64]
print(f"col_shape: {col_emb.shape}")
print(f"gnn_shape: {gnn_emb.shape}")
assert col_emb.shape[0] == gnn_emb.shape[0], "Mismatch in item count"

# Concatenate: [LLM || GNN]
fused_emb = np.concatenate([col_emb, gnn_emb], axis=1)  # shape: [num_items, 128]

# Save fused embedding
with open("graph_data/fused_pca64_itm_emb_np.pkl", "wb") as f:
    pickle.dump(fused_emb, f)


with open("graph_data/fused_pca64_itm_emb_np.pkl", "rb") as f:
    fused_data = pickle.load(f)
print(f"the dimension of fused embedding is {fused_data.shape}")

col_shape: (57289, 64)
gnn_shape: (57289, 64)
the dimension of fused embedding is (57289, 128)


In [22]:
print(f"collobrative embedding item [0]: \n {col_emb[0]}")
print(f"gnn embedding item [0]: \n {gnn_emb[0]}")

collobrative embedding item [0]: 
 [ 0.06981667  0.14442578  0.0535023  -0.08414154  0.09463756 -0.03626085
 -0.03035027  0.00369254  0.00710129 -0.06218467  0.01594733 -0.05304417
  0.08057744  0.00208946 -0.03156804  0.00423898 -0.02121462  0.01415135
  0.08333678  0.0289374  -0.00776773  0.02657525 -0.01858179 -0.00593037
  0.00067705 -0.00235615 -0.02065355 -0.03762729 -0.03606503 -0.03609613
 -0.00611113 -0.00296575 -0.01485604  0.07745654 -0.01674157  0.01220707
 -0.01177268  0.03651463  0.03259253 -0.03159697 -0.00206229  0.00366125
 -0.03166092  0.01766294  0.02784831 -0.00016647 -0.03964143 -0.00597224
  0.05646616  0.01536507  0.01450725  0.01923479  0.00642199  0.0380005
 -0.00581556 -0.01339675 -0.02487696  0.02866836  0.02915967  0.03548247
  0.02508577  0.04911524  0.03779834  0.03975626]
gnn embedding item [0]: 
 [-4.9440339e-03 -2.4531004e-03 -7.0671751e-03  7.7628912e-03
  1.8639162e-03 -6.2251380e-03  6.0575115e-03 -3.1166142e-03
 -6.3754283e-03 -1.4521150e-03 -1.7033

##### Map GNN embedding to the same scale as pca embedding

In [24]:
target_mean = col_emb.mean()
target_std = col_emb.std()

gnn_item_emb = (gnn_emb - gnn_emb.mean()) / gnn_emb.std()
gnn_item_emb = gnn_item_emb * target_std + target_mean

with open("graph_data/gnn_itm_emb_np.pkl", "wb") as f:
    pickle.dump(gnn_item_emb, f)

In [25]:
with open("graph_data/gnn_itm_emb_np.pkl", "rb") as f:
    gnn_emb = pickle.load(f)

print(f"collobrative embedding item [0]: \n {col_emb[0]}")
print(f"gnn embedding item [0]: \n {gnn_emb[0]}")

collobrative embedding item [0]: 
 [ 0.06981667  0.14442578  0.0535023  -0.08414154  0.09463756 -0.03626085
 -0.03035027  0.00369254  0.00710129 -0.06218467  0.01594733 -0.05304417
  0.08057744  0.00208946 -0.03156804  0.00423898 -0.02121462  0.01415135
  0.08333678  0.0289374  -0.00776773  0.02657525 -0.01858179 -0.00593037
  0.00067705 -0.00235615 -0.02065355 -0.03762729 -0.03606503 -0.03609613
 -0.00611113 -0.00296575 -0.01485604  0.07745654 -0.01674157  0.01220707
 -0.01177268  0.03651463  0.03259253 -0.03159697 -0.00206229  0.00366125
 -0.03166092  0.01766294  0.02784831 -0.00016647 -0.03964143 -0.00597224
  0.05646616  0.01536507  0.01450725  0.01923479  0.00642199  0.0380005
 -0.00581556 -0.01339675 -0.02487696  0.02866836  0.02915967  0.03548247
  0.02508577  0.04911524  0.03779834  0.03975626]
gnn embedding item [0]: 
 [-0.04472134 -0.02084337 -0.06507367  0.07708662  0.02053933 -0.05700194
  0.06073893 -0.02720378 -0.05844261 -0.01124798 -0.01365586  0.04902019
  0.06131523  