# Notebook Purpose and Overview

In this notebook, we show the pipeline to load 5 datasets (cora, pubmed, ogbn-arxiv, arxiv-2023 and ogbn-product) and make predictions for node classification tasks via the OpenAI API.

Credit: GPT-4 helped with code generation for this notebook as well as for other utility functions.

In [None]:
import numpy as np
import os
import openai
from utils.utils import process_and_compare_predictions, load_data, sample_test_nodes, map_arxiv_labels

# Read the OpenAI API key from the OPENAI_API_KEY environment variable (KeyError if it is unset).
openai.api_key = os.environ["OPENAI_API_KEY"]

In [None]:
%load_ext autoreload
%autoreload 2

## Define dataset name

In [34]:
# Choose the dataset to evaluate; move the active line to switch datasets.
# dataset_name = "arxiv_2023"
# dataset_name = "pubmed"
dataset_name = "cora"
# dataset_name = "arxiv"
# dataset_name = "product"

# Both arXiv variants share the "arxiv" source; every other dataset is its own source.
source = "arxiv" if dataset_name in ("arxiv", "arxiv_2023") else dataset_name

# use_ori_arxiv_label=False # only for using original Arxiv identifier in system prompting for ogbn-arxiv
# Label style for arXiv datasets; the other values listed are "identifier" and "natural language".
arxiv_style = "subcategory"
# Set to True to include options in the prompt for arxiv datasets.
include_options = False

## Load dataset and raw texts

In [None]:
# Load the dataset object and its raw texts (seed fixed to 42), then show the dataset object.
data, text = load_data(dataset_name, use_text=True, seed=42)
print(data)

# Labels are remapped only for arXiv sources whose style is not "subcategory".
if source == "arxiv":
    if arxiv_style != "subcategory":
        text = map_arxiv_labels(data, text, source, arxiv_style)

# Distinct label values; passed as `options=` to the prediction runs.
# NOTE(review): a set has no guaranteed order — confirm the prompt builder does not depend on it.
options = set(text["label"])

## Sample test indices. The default setting is the full test set for cora and arxiv_2023, and 1000 nodes for the other datasets. For demonstration purposes, we set the sample size to 3.

In [None]:
# Default sample size: the full test set for cora and arxiv_2023, 1000 otherwise.
sample_size = len(data.test_id) if dataset_name in ("arxiv_2023", "cora") else 1000

# Demonstration override: only 3 test nodes are sampled.
sample_size = 3

node_indices = sample_test_nodes(data, text, sample_size, dataset_name)

print(f"{node_indices = }")

# Positions 0..sample_size-1 pick the sampled nodes handed to the prediction runs.
idx_list = [*range(sample_size)]

node_index_list = [node_indices[position] for position in idx_list]

## Check dataset splits

In [None]:
# Display the sum of each split mask (train, validation, test) as the cell output.
# Presumably these are the node counts per split — confirm the masks are boolean.
data.train_mask.sum(), data.val_mask.sum(), data.test_mask.sum()

## Define max number for 1-hop and 2-hop neighbors

In [None]:
# Neighbor caps (max_papers_1, max_papers_2): (40, 10) for "product", (20, 5) for every other dataset.
max_papers_1, max_papers_2 = (40, 10) if dataset_name == "product" else (20, 5)

# Below we test the loaded dataset in the two contexts outlined in the paper.
- Rich textual context
- Scarce textual context

# Rich textual context: 

For the target node, the title and abstract are given for cora, pubmed, ogbn-arxiv and arxiv-2023; the product title and product content are given for ogbn-product.

## Zero-shot

In [None]:
# Rich context, zero-shot: ego mode with abstracts, no chain-of-thought, no few-shot examples.
mode = "ego"
hop = 1
include_abs = True
include_label = False  # assigned here but not passed to the call below
zero_shot_CoT = False
few_shot = False

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
    options=options,
)

In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## Few-shot

In [None]:
# Rich context, few-shot: same configuration as the zero-shot run, but with few_shot enabled.
# The earlier contradictory `few_shot=False` assignment was dead code and has been removed.
mode = "ego"
zero_shot_CoT = False
hop = 1
include_abs = True
include_label = False  # assigned here but not passed to the call below
few_shot = True

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
    options=options,
)

In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## Zero-shot CoT

In [None]:
# Rich context, zero-shot CoT: ego mode with abstracts and zero_shot_CoT enabled, no few-shot examples.
mode = "ego"
hop = 1
include_abs = True
few_shot = False
zero_shot_CoT = True

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
    options=options,
)

In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 1-hop title+label

In [None]:
# Rich context, 1-hop neighbors with labels included.
mode = "neighbors"
hop = 1
include_label = True
include_abs = True
# Set explicitly so this run does not inherit zero_shot_CoT=True from the
# Zero-shot CoT cell above; this matches the scarce-context neighbor runs,
# which execute with both flags False.
zero_shot_CoT = False
few_shot = False

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    include_label=include_label,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
    options=options,
)


In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 1-hop title

In [None]:
# Rich context, 1-hop neighbors without labels.
mode = "neighbors"
hop = 1
include_label = False
include_abs = True
# Set explicitly so this run does not inherit zero_shot_CoT=True from the
# Zero-shot CoT cell above; this matches the scarce-context neighbor runs,
# which execute with both flags False.
zero_shot_CoT = False
few_shot = False

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    include_label=include_label,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
    options=options,
)


In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 2-hop title+label

In [None]:
# Rich context, 2-hop neighbors with labels included.
mode = "neighbors"
hop = 2
include_label = True
include_abs = True
# Set explicitly so this run does not inherit zero_shot_CoT=True from the
# Zero-shot CoT cell above; this matches the scarce-context neighbor runs,
# which execute with both flags False.
zero_shot_CoT = False
few_shot = False

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    include_label=include_label,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
    options=options,
)


In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 2-hop title

In [None]:
# Rich context, 2-hop neighbors without labels.
mode = "neighbors"
hop = 2
include_label = False
include_abs = True
# Set explicitly so this run does not inherit zero_shot_CoT=True from the
# Zero-shot CoT cell above; this matches the scarce-context neighbor runs,
# which execute with both flags False.
zero_shot_CoT = False
few_shot = False

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    include_label=include_label,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
    options=options,
)


In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 1-hop attention:

1-hop attention means attention extraction and prediction over 1-hop neighbors. The attentions for test nodes are given under `\attention`.

In [None]:
# Rich context, 1-hop attention: neighbors mode with use_attention enabled.
mode = "neighbors"
hop = 1
include_abs = True
include_label = False  # assigned here but not passed to the call below
use_attention = True
zero_shot_CoT = False
few_shot = False

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
    use_attention=use_attention,
    options=options,
)

# Scarce textual context:

Only the title of each node is given.


## Zero-shot

In [None]:
# Scarce context, zero-shot: ego mode without abstracts, no chain-of-thought, no few-shot examples.
# NOTE(review): unlike the rich-context runs, this call does not pass options= — confirm that is intended.
mode = "ego"
hop = 1
include_abs = False
include_label = False
zero_shot_CoT = False
few_shot = False

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    include_label=include_label,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
)


In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 1-hop title+label

In [None]:
# Scarce context, 1-hop neighbors with labels included.
mode = "neighbors"
hop = 1
include_label = True
# Set explicitly (same values an in-order run would inherit from the scarce
# Zero-shot cell) so this cell no longer depends on which cell ran before it.
include_abs = False
zero_shot_CoT = False
few_shot = False

# NOTE(review): unlike the rich-context runs, this call does not pass options= — confirm that is intended.
# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    include_label=include_label,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
)


In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 1-hop title

In [None]:
# Scarce context, 1-hop neighbors without labels.
mode = "neighbors"
hop = 1
include_label = False
# Set explicitly (same values an in-order run would inherit from the scarce
# Zero-shot cell) so this cell no longer depends on which cell ran before it.
include_abs = False
zero_shot_CoT = False
few_shot = False

# NOTE(review): unlike the rich-context runs, this call does not pass options= — confirm that is intended.
# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    include_label=include_label,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    hop=hop,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
)

In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 2-hop title+label

In [None]:
# Scarce context, 2-hop neighbors with labels included.
mode = "neighbors"
hop = 2
include_label = True
# Set explicitly (same values an in-order run would inherit from the scarce
# Zero-shot cell) so this cell no longer depends on which cell ran before it.
include_abs = False
zero_shot_CoT = False
few_shot = False

# NOTE(review): unlike the rich-context runs, this call does not pass options= — confirm that is intended.
# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    include_label=include_label,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
)

In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 2-hop title

In [None]:
# Scarce context, 2-hop neighbors without labels.
mode = "neighbors"
hop = 2
include_label = False
# Set explicitly (same values an in-order run would inherit from the scarce
# Zero-shot cell) so this cell no longer depends on which cell ran before it.
include_abs = False
zero_shot_CoT = False
few_shot = False

# NOTE(review): unlike the rich-context runs, this call does not pass options= — confirm that is intended.
# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    include_label=include_label,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    hop=hop,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
)

In [None]:
# Report the accuracy and the wrong indexes returned by the preceding run.
print(f"Returned accuracy: {accuracy!s}")
print(f"Returned wrong indexes: {wrong_indexes_list!s}")

## 1-hop attention

In [None]:
# Scarce context, 1-hop attention: neighbors mode without abstracts, use_attention enabled.
mode = "neighbors"
hop = 1
include_abs = False
include_label = False  # assigned here but not passed to the call below
use_attention = True
# Set explicitly (same values an in-order run would inherit from the scarce
# Zero-shot cell) so this cell no longer depends on which cell ran before it.
zero_shot_CoT = False
few_shot = False

# The call's two results are unpacked into the accuracy and the wrong-index list.
accuracy, wrong_indexes_list = process_and_compare_predictions(
    node_index_list,
    data,
    text,
    dataset_name=dataset_name,
    source=source,
    mode=mode,
    max_papers_1=max_papers_1,
    max_papers_2=max_papers_2,
    hop=hop,
    arxiv_style=arxiv_style,
    include_abs=include_abs,
    include_options=include_options,
    zero_shot_CoT=zero_shot_CoT,
    few_shot=few_shot,
    use_attention=use_attention,
    options=options,
)