# Source:
- PyG high-level documentation:
    - https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html#feature-store
<br><br>
- PyG test of the feature store:
    - https://github.com/pyg-team/pytorch_geometric/blob/901a255346009c7294fd3cc1e825aa441f1dbd4f/torch_geometric/testing/feature_store.py
<br><br>
- YouTube PyG batching video:
    - https://www.youtube.com/watch?v=mz9xYNg9Ofs
    
    
---- 
Actually, use the 2.2.0 tagged version:
https://github.com/pyg-team/pytorch_geometric/tree/2.2.0

In [7]:
import sys
import platform
import torch
import torch_geometric

# Report the environment this notebook was run in: one "label value" line
# per entry, in the same order as the recorded output.
version_report = (
    ("Platform", f"{platform.system()} {platform.release()}"),
    ("Python version", sys.version),
    ("torch", torch.__version__),
    ("torch_geomeric", torch_geometric.__version__),
)

for label, value in version_report:
    # A blank line separates the interpreter details from the library versions.
    if label == "torch":
        print()
    print(label, value)



Platform Darwin 22.3.0
Python version 3.8.12 (default, Jul 12 2022, 16:17:42) 
[Clang 13.1.6 (clang-1316.0.21.2.5)]
torch 1.13.1
torch_geomeric 2.2.0


### This just copies the following test
- https://github.com/pyg-team/pytorch_geometric/blob/master/test/loader/test_neighbor_loader.py#L352

- HeteroData has dummy in-memory implementations of all interfaces (i.e., also the graph store and the feature store).
    - https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/data/hetero_data.py

In [5]:
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import NeighborLoader
from torch_sparse import SparseTensor

In [6]:

# Create three HeteroData objects: one used as the feature store, one used as
# the graph store, and one plain reference copy holding the same content.
feature_store, graph_store, data = HeteroData(), HeteroData(), HeteroData()

# Register the features of every node type on both the reference object and
# the feature store: 'paper' holds the values 0..99, 'author' holds 100..299.
for node_type, x in (('paper', torch.arange(100)),
                     ('author', torch.arange(100, 300))):
    data[node_type].x = x
    feature_store.put_tensor(x, group_name=node_type, attr_name='x',
                             index=None)

def get_edge_index(num_src_nodes, num_dst_nodes, num_edges, dtype=torch.int64):
    """Sample a random edge index tensor of shape ``(2, num_edges)``.

    Row 0 holds source node ids drawn uniformly from ``[0, num_src_nodes)``
    and row 1 holds destination node ids drawn uniformly from
    ``[0, num_dst_nodes)``. Both rows use the requested ``dtype``.
    """
    shape = (num_edges, )
    # Sources are sampled before destinations so the consumption order of the
    # global random generator stays the same.
    sources = torch.randint(num_src_nodes, shape, dtype=dtype)
    targets = torch.randint(num_dst_nodes, shape, dtype=dtype)
    return torch.stack((sources, targets))

# COO layout: 500 random paper -> paper edges, handed to the graph store as a
# (row, col) pair of 1-D tensors.
edge_index = get_edge_index(100, 100, 500)
data['paper', 'to', 'paper'].edge_index = edge_index
coo = edge_index.unbind(dim=0)  # -> (edge_index[0], edge_index[1])
graph_store.put_edge_index(
    edge_index=coo,
    edge_type=('paper', 'to', 'paper'),
    layout='coo',
    size=(100, 100),
)

# CSR layout: 1000 random paper -> author edges, handed to the graph store as
# the first two entries of SparseTensor.csr().
edge_index = get_edge_index(100, 200, 1000)
data['paper', 'to', 'author'].edge_index = edge_index
# Pass the sizes explicitly. Without them SparseTensor infers the shape from
# the largest sampled index, so the row pointer would come out too short for
# the declared size=(100, 200) whenever the highest-numbered paper happens to
# draw no edge.
csr = SparseTensor.from_edge_index(
    edge_index,
    sparse_sizes=(100, 200),
).csr()[:2]
graph_store.put_edge_index(edge_index=csr,
                           edge_type=('paper', 'to', 'author'),
                           layout='csr', size=(100, 200))

# CSC layout: 1000 random author -> paper edges. The CSC form is obtained by
# building the transposed tensor (rows = destination papers, cols = source
# authors) and taking its CSR form.
edge_index = get_edge_index(200, 100, 1000)
data['author', 'to', 'paper'].edge_index = edge_index
# Pass the transposed sizes explicitly. Without them SparseTensor infers the
# shape from the largest sampled index, so the pointer array would come out
# too short for the declared size=(200, 100) whenever the highest-numbered
# paper happens to receive no edge.
# The [-2::-1] slice keeps the first two entries of csr() in swapped order.
csc = SparseTensor(
    row=edge_index[1],
    col=edge_index[0],
    sparse_sizes=(100, 200),
).csr()[-2::-1]
graph_store.put_edge_index(edge_index=csc,
                           edge_type=('author', 'to', 'paper'),
                           layout='csc', size=(200, 100))

# COO layout, pre-sorted: 100 random author -> author edges whose columns are
# reordered by ascending destination id before being registered, so the graph
# store can be told the input is already sorted.
edge_index = get_edge_index(200, 200, 100)
dst_order = edge_index[1].argsort()
edge_index = edge_index[:, dst_order]
data['author', 'to', 'author'].edge_index = edge_index
coo = edge_index.unbind(dim=0)  # -> (edge_index[0], edge_index[1])
graph_store.put_edge_index(
    edge_index=coo,
    edge_type=('author', 'to', 'author'),
    layout='coo',
    size=(200, 200),
    is_sorted=True,
)

# Build a neighbor loader on top of the (feature store, graph store) pair.
# The seed nodes are the 100 'paper' nodes in mini-batches of 20, and
# num_neighbors is -1 for each of the two hops.
loader = NeighborLoader(
    (feature_store, graph_store),
    num_neighbors=[-1, -1],
    input_nodes=('paper', range(100)),
    batch_size=20,
)

# Walk through every mini-batch once; nothing is done with the batches yet.
for batch in loader:
    # TODO, actual training here
    pass