knn retrieval evaluation #770

Open
wants to merge 8 commits into base: main
Changes from 3 commits
128 changes: 128 additions & 0 deletions examples/knn_retriever/build_index.py
@@ -0,0 +1,128 @@
import torch as th
Contributor: Please add a license header to each Python file.
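A possible header sketch, assuming the repository's usual Apache-2.0 module-docstring convention (verify against existing GraphStorm source files before copying):

```python
"""
    Copyright 2023 Contributors

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.

    KNN retrieval evaluation example: build a FAISS index over saved
    GraphStorm node embeddings and report recall@100.
"""
```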

import time
import graphstorm as gs
from graphstorm.utils import is_distributed
Contributor: It is better to group the graphstorm-related imports together. The conventional import order is: standard-library modules such as os and time first, then pip-installed packages, then local code.
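A sketch of the suggested grouping, using only the imports that already appear in this file:

```python
# standard library
import time
from collections import defaultdict

# third-party packages
import dgl
import faiss
import numpy as np
import torch as th

# local / graphstorm
import graphstorm as gs
from graphstorm.config import GSConfig, get_argument_parser
from graphstorm.dataloading import GSgnnNodeDataLoader, GSgnnNodeTrainData
from graphstorm.model.utils import load_gsgnn_embeddings
from graphstorm.utils import is_distributed, setup_device
```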

import faiss
import dgl
import numpy as np
from collections import defaultdict
from graphstorm.config import get_argument_parser
from graphstorm.config import GSConfig
from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.dataloading import GSgnnNodeTrainData
from graphstorm.utils import setup_device
from graphstorm.model.utils import load_gsgnn_embeddings

def calculate_recall(pred, ground_truth):
Contributor: Can you describe in the function docstring how recall is computed?
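One possible docstring sketch answering the reviewer's question; recall here is the fraction of a node's ground-truth neighbors that appear among the retrieved candidates:

```python
def calculate_recall(pred, ground_truth):
    """Compute recall for a single query.

    Parameters
    ----------
    pred : list or set
        Node IDs returned by the index search for the query node.
    ground_truth : set
        Node IDs of the query node's true neighbors.

    Returns
    -------
    float
        len(pred & ground_truth) / len(ground_truth), i.e. the fraction
        of true neighbors that were retrieved.
    """
```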

# Convert pred to a set if it's not already one
if not isinstance(pred, set):
pred = set(pred)

overlap = len(pred & ground_truth)
#if overlap > 0:
Contributor: Remove the commented-out code.

# return 1
#else:
# return 0
return overlap / len(ground_truth)

def main(config_args):
""" main function
"""
config = GSConfig(config_args)
embs = load_gsgnn_embeddings(config.save_embed_path)

index_dimension = embs[config.target_ntype].size(1)
# Number of clusters (higher values lead to better recall but slower search)
#nlist = 750
#quantizer = faiss.IndexFlatL2(index_dimension) # Use Flat index for quantization
Contributor: Remove the commented-out code.

#index = faiss.IndexIVFFlat(quantizer, index_dimension, nlist, faiss.METRIC_INNER_PRODUCT)
#index.train(embs[config.target_ntype])
index = faiss.IndexFlatIP(index_dimension)
index.add(embs[config.target_ntype])
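For reference, a minimal sketch of the approximate IVF variant that the commented-out lines hint at; nlist and nprobe are illustrative, untuned values, and FAISS expects contiguous float32 data:

```python
# Cluster the embeddings into nlist cells and probe a few cells per query.
# Faster than exhaustive IndexFlatIP search, at some cost in recall.
xb = embs[config.target_ntype].numpy().astype('float32')
nlist = 750
quantizer = faiss.IndexFlatIP(index_dimension)
ivf_index = faiss.IndexIVFFlat(quantizer, index_dimension, nlist,
                               faiss.METRIC_INNER_PRODUCT)
ivf_index.train(xb)
ivf_index.add(xb)
ivf_index.nprobe = 32  # number of cells searched per query
```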

#print(scores.abs().mean())
Contributor: Remove this commented-out line.


gs.initialize(ip_config=config.ip_config, backend=config.backend)
device = setup_device(config.local_rank)
#index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(embedding_size))
# Define the training dataset
train_data = GSgnnNodeTrainData(
config.graph_name,
config.part_config,
train_ntypes=config.target_ntype,
eval_ntypes=config.eval_target_ntype,
label_field=None,
node_feat_field=None,
)
#for i in range(embs[config.target_ntype].shape[0]):
# print(embs[config.target_ntype][i,:].sum(), train_data.g.ndata['bert_h'][i].sum())
# breakpoint()
# embs[config.target_ntype][i,:] = train_data.g.ndata['bert_h'][i]

#print( train_data.g.ndata['bert_h'][0,:], embs[config.target_ntype][0,:])
#print(train_data.g.ndata['bert_h'])

# TODO: devise a dataloader that can exclude targets and add train_mask like LP Loader
test_dataloader = GSgnnNodeDataLoader(
train_data,
train_data.train_idxs,
fanout=[-1],
batch_size=config.eval_batch_size,
device=device,
train_task=False,
)
dataloader_iter = iter(test_dataloader)
len_dataloader = max_num_batch = len(test_dataloader)
tensor = th.tensor([len_dataloader], device=device)
if is_distributed():
th.distributed.all_reduce(tensor, op=th.distributed.ReduceOp.MAX)
Contributor: Do we need to make it distributed?

max_num_batch = tensor[0]
recall = []
max_ = []
for iter_l in range(max_num_batch):
ground_truth = defaultdict(set)
input_nodes, seeds, blocks = next(dataloader_iter)
#block_graph = dgl.block_to_graph(blocks[0])
src_id = blocks[0].srcdata[dgl.NID].tolist()
dst_id = blocks[0].dstdata[dgl.NID].tolist()
#print(blocks[0].edges(form='uv', etype='also_buy'))
#breakpoint()
# print(dgl.NID)
if 'also_buy' in blocks[0].etypes:
Contributor: Is this implemented specifically for the Amazon Review dataset?
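A sketch of one way to make this generic, iterating over a list of evaluation edge types instead of hard-coding the Amazon Review relations (deriving that list from config.eval_etype is an assumption, not something this PR does):

```python
# Collect ground-truth neighbors for every evaluation edge type present in the block.
eval_etypes = ["also_buy", "also_buy-rev"]  # could instead be derived from config.eval_etype
for etype in eval_etypes:
    if etype in blocks[0].etypes:
        src, dst = blocks[0].edges(form='uv', etype=etype)
        for s, d in zip(src.tolist(), dst.tolist()):
            ground_truth[dst_id[d]].add(src_id[s])
```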

#src, dst = block_graph.edges(form='uv', etype='also_buy')
src, dst = blocks[0].edges(form='uv', etype='also_buy')
for s,d in zip(src.tolist(),dst.tolist()):
ground_truth[dst_id[d]].add(src_id[s])
#ground_truth[src_id[s]].add(dst_id[d])
if 'also_buy-rev' in blocks[0].etypes:
#src, dst = block_graph.edges(form='uv', etype='also_buy-rev')
src, dst = blocks[0].edges(form='uv', etype='also_buy-rev')
for s,d in zip(src.tolist(),dst.tolist()):
ground_truth[dst_id[d]].add(src_id[s])
#ground_truth[src_id[s]].add(dst_id[d])
query_idx = list(ground_truth.keys())
#print(ground_truth)
#breakpoint()
ddd,lll = index.search(embs[config.target_ntype][query_idx],100 + 1)
#knn_result = lll.tolist()

for idx,query in enumerate(query_idx):
recall.append(calculate_recall(lll[idx, 1:], ground_truth[query]))
max_.append(query)
#print(recall)
if gs.get_rank() == 0:
#print(query_idx, lll)
#print(max_num_batch, len(recall), np.mean(recall))
print(f'recall@100: {np.mean(recall)}')

def generate_parser():
"""Generate an argument parser"""
parser = get_argument_parser()
return parser

if __name__ == "__main__":
arg_parser = generate_parser()

args = arg_parser.parse_args()
print(args)
main(args)
46 changes: 46 additions & 0 deletions examples/knn_retriever/embedding_config.yaml
@@ -0,0 +1,46 @@
gsf:
basic:
backend: gloo
verbose: false
save_perf_results_path: null
gnn:
model_encoder_type: mlp
fanout: "5,5"
node_feat_name:
- item:bert_h
num_layers: 2
hidden_size: 768
use_mini_batch_infer: true
input:
restore_model_path: null
output:
save_model_path: null
save_embed_path: /shared_data/graphstorm/examples/peft_llm_gnn/results/lp/Video_Games
hyperparam:
dropout: 0.
lr: 0.001
num_epochs: 1
batch_size: 512
eval_batch_size: 512
wd_l2norm: 0.00001
no_validation: false
rgcn:
num_bases: -1
use_self_loop: true
lp_decoder_type: dot_product
sparse_optimizer_lr: 1e-2
use_node_embeddings: false
link_prediction:
num_negative_edges: 1
num_negative_edges_eval: 100
contrastive_loss_temperature: 0.1
lp_loss_func: contrastive
lp_embed_normalizer: l2_norm
train_negative_sampler: inbatch_joint
target_ntype: item
eval_etype:
- "item,also_buy,item"
train_etype:
- "item,also_buy,item"
exclude_training_targets: true
reverse_edge_types_map: ["item,also_buy,also_buy-rev,item"]
18 changes: 18 additions & 0 deletions examples/knn_retriever/run_knn.sh
@@ -0,0 +1,18 @@
WORKSPACE=/shared_data/graphstorm/examples/knn_retriever/
DATASPACE=/shared_data/graphstorm/examples/peft_llm_gnn/
dataset=amazon_review
domain=$1

python -m graphstorm.run.launch \
--workspace "$WORKSPACE" \
--part-config "$DATASPACE"/datasets/amazon_review_"$domain"/amazon_review.json \
--ip-config "$DATASPACE"/ip_list.txt \
--num-trainers 1 \
--num-servers 1 \
--num-samplers 0 \
--ssh-port 22 \
--do-nid-remap False \
build_index.py \
--cf "$WORKSPACE"/embedding_config.yaml \
--save-model-path "$DATASPACE"/model/lp/"$domain"/ \
--save-embed-path "$DATASPACE"/results/lp/"$domain"/
2 changes: 1 addition & 1 deletion examples/peft_llm_gnn/AR_Video_Games.json
@@ -51,7 +51,7 @@
"transform": {"name": "bert_hf",
"bert_model": "bert-base-uncased",
"infer_batch_size": 128,
"max_seq_length": 32}
"max_seq_length": 128}
}
],
"labels": [
2 changes: 1 addition & 1 deletion examples/peft_llm_gnn/lp_config_Video_Games.yaml
@@ -28,7 +28,7 @@ gsf:
hyperparam:
dropout: 0.
lr: 0.0001
num_epochs: 3
num_epochs: 4
batch_size: 16
eval_batch_size: 16
wd_l2norm: 0.00001
7 changes: 5 additions & 2 deletions examples/peft_llm_gnn/main_lp.py
@@ -102,7 +102,7 @@ def main(config_args):
save_model_frequency=config.save_model_frequency,
use_mini_batch_infer=True
)

# Load the best checkpoint
best_model_path = trainer.get_best_model_path()
model.restore_model(best_model_path)
@@ -123,7 +123,10 @@
# Run inference on the inference dataset and save the GNN embeddings in the specified path.
infer.infer(train_data, test_dataloader, save_embed_path=config.save_embed_path,
edge_mask_for_gnn_embeddings='train_mask',
use_mini_batch_infer=True, infer_batch_size=config.eval_batch_size)
use_mini_batch_infer=True,
node_id_mapping_file=config.node_id_mapping_file,
save_embed_format=config.save_embed_format,
infer_batch_size=config.eval_batch_size)

def generate_parser():
"""Generate an argument parser"""
27 changes: 27 additions & 0 deletions python/graphstorm/model/utils.py
@@ -27,6 +27,7 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import dgl
import pandas as pd

from ..config import GRAPHSTORM_LP_EMB_L2_NORMALIZATION
from ..gconstruct.file_io import stream_dist_tensors_to_hdf5
@@ -1065,6 +1066,32 @@ def save_full_node_embeddings(g, save_embed_path,

save_shuffled_node_embeddings(shuffled_embs, save_embed_path, save_embed_format)

def load_gsgnn_embeddings(emb_path):
Contributor: Can you check whether load_pytorch_embedding would be useful here?

'''Load node embeddings saved by `save_full_node_embeddings` into a dict of
torch Tensors keyed by node type, with rows ordered by node ID.
'''
with open(os.path.join(emb_path, "emb_info.json"), 'r', encoding='utf-8') as f:
emb_info = json.load(f)
embs = {}
for ntype in emb_info["emb_name"]:
path = os.path.join(emb_path, ntype)
ntype_emb_files = os.listdir(path)
nid_files = [fname for fname in ntype_emb_files \
if fname.startswith("embed_nids-") and fname.endswith("pt")]
emb_files = [fname for fname in ntype_emb_files \
if fname.startswith("embed-") and fname.endswith("pt")]
num_parts = len(emb_files)
embeddings_list = []
nid_list = []
for i in range(num_parts):
embeddings_list.append(th.load(os.path.join(path, emb_files[i])))
nid_list.append(th.load(os.path.join(path, nid_files[i])))
# Convert the list of embeddings to a PyTorch tensor
embeddings_tensor = th.cat(embeddings_list, dim=0)
nids_tensor = th.cat(nid_list, dim=0)
result_tensor = th.zeros_like(embeddings_tensor)
result_tensor[nids_tensor] = embeddings_tensor
embs[ntype] = result_tensor
return embs
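A small usage sketch, assuming the path points at a directory written by save_full_node_embeddings and that 'item' is the saved node type (as in the example config):

```python
# Load every saved node-embedding file and look up one node type.
embs = load_gsgnn_embeddings("/path/to/saved/embeddings")
item_emb = embs["item"]      # rows indexed by node ID
print(item_emb.shape)        # (num_item_nodes, hidden_size)
```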

def save_embeddings(emb_path, embeddings, rank, world_size,
device=th.device('cpu'), node_id_mapping_file=None,