# Sources:
- PyG high-level documentation:
    - https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html#feature-store
<br><br>
- PyG feature-store test implementation:
    - https://github.com/pyg-team/pytorch_geometric/blob/901a255346009c7294fd3cc1e825aa441f1dbd4f/torch_geometric/testing/feature_store.py
<br><br>
- YouTube video on PyG batching:
    - https://www.youtube.com/watch?v=mz9xYNg9Ofs
    
    
---- 
Note: the code below was actually run against the tagged 2.2.0 release:
https://github.com/pyg-team/pytorch_geometric/tree/2.2.0

In [7]:
import sys
import platform
import torch
import torch_geometric

# Report the runtime environment so the recorded outputs can be reproduced.
print("Platform", platform.system(), platform.release())
print("Python version", sys.version)

print("torch", torch.__version__)
print("torch_geometric", torch_geometric.__version__)



Platform Darwin 22.3.0
Python version 3.8.12 (default, Jul 12 2022, 16:17:42) 
[Clang 13.1.6 (clang-1316.0.21.2.5)]
torch 1.13.1
torch_geomeric 2.2.0


### Copied from this PyG test
- https://github.com/pyg-team/pytorch_geometric/blob/master/test/loader/test_neighbor_loader.py#L352

- `HeteroData` provides simple in-memory implementations of all the required interfaces (i.e., it can also serve as the graph store and the feature store).
    - https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/data/hetero_data.py

In [5]:
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import NeighborLoader
from torch_sparse import SparseTensor

In [6]:

# Set up the feature store and the graph store (both HeteroData here).
# Also set up a reference HeteroData that holds the same tensors in the usual
# attribute-style layout.
feature_store = HeteroData()
graph_store = HeteroData()
data = HeteroData()

# Node features: 100 'paper' nodes and 200 'author' nodes, each with a single
# scalar feature (consecutive integers). Every tensor is stored both on the
# reference object and in the feature store.
for node_type, x in (('paper', torch.arange(100)),
                     ('author', torch.arange(100, 300))):
    data[node_type].x = x
    feature_store.put_tensor(x, group_name=node_type, attr_name='x',
                             index=None)

def get_edge_index(num_src_nodes, num_dst_nodes, num_edges, dtype=torch.int64):
    """Return a random ``[2, num_edges]`` edge index.

    Row 0 holds source indices drawn uniformly from ``[0, num_src_nodes)``.
    Row 1 holds destination indices drawn uniformly from
    ``[0, num_dst_nodes)``. Duplicate edges are possible.
    """
    shape = (num_edges, )
    # Sources are drawn before destinations (same RNG consumption order).
    src = torch.randint(num_src_nodes, shape, dtype=dtype)
    dst = torch.randint(num_dst_nodes, shape, dtype=dtype)
    return torch.stack((src, dst))

# COO: a plain (row, col) pair of index tensors.
edge_index = get_edge_index(100, 100, 500)
data['paper', 'to', 'paper'].edge_index = edge_index
coo = (edge_index[0], edge_index[1])
graph_store.put_edge_index(edge_index=coo,
                           edge_type=('paper', 'to', 'paper'),
                           layout='coo', size=(100, 100))

# CSR: (rowptr, col), compressed over the source nodes.
# `sparse_sizes` is passed explicitly. Without it, SparseTensor infers the
# number of rows from the largest index present. `rowptr` would then be too
# short (disagreeing with `size` below) whenever the highest-indexed source
# node happens to have no edges.
edge_index = get_edge_index(100, 200, 1000)
data['paper', 'to', 'author'].edge_index = edge_index
csr = SparseTensor.from_edge_index(
    edge_index, sparse_sizes=(100, 200)).csr()[:2]
graph_store.put_edge_index(edge_index=csr,
                           edge_type=('paper', 'to', 'author'),
                           layout='csr', size=(100, 200))

# CSC: (row, colptr), compressed over the destination nodes.
# Build a CSR matrix of the transposed graph (destination nodes as rows).
# Its (rowptr, col) are (colptr, row) of the original graph, so
# `csr()[-2::-1]` picks (col, rowptr) == (row, colptr).
# Sizes are given in transposed order (num_dst, num_src), for the same reason
# as the CSR case above.
edge_index = get_edge_index(200, 100, 1000)
data['author', 'to', 'paper'].edge_index = edge_index
csc = SparseTensor(row=edge_index[1], col=edge_index[0],
                   sparse_sizes=(100, 200)).csr()[-2::-1]
graph_store.put_edge_index(edge_index=csc,
                           edge_type=('author', 'to', 'paper'),
                           layout='csc', size=(200, 100))

# COO (sorted): edges are reordered by destination (column) index, which is
# what the `is_sorted=True` flag passed below declares.
edge_index = get_edge_index(200, 200, 100)
edge_index = edge_index[:, edge_index[1].argsort()]
data['author', 'to', 'author'].edge_index = edge_index
coo = (edge_index[0], edge_index[1])
graph_store.put_edge_index(edge_index=coo,
                           edge_type=('author', 'to', 'author'),
                           layout='coo', size=(200, 200), is_sorted=True)

# Construct a neighbor loader that reads from the (feature_store,
# graph_store) pair instead of a single in-memory data object.
# Seeds are the 100 'paper' nodes, in batches of 20.
# `num_neighbors=[-1] * 2` requests all neighbors (-1) for each of two hops.
loader = NeighborLoader((feature_store, graph_store), batch_size=20,
                        input_nodes=('paper', range(100)),
                        num_neighbors=[-1] * 2)


for batch in loader:
    pass
    # TODO: actual training here
    # TODO, actual training here

In [8]:
coo

(tensor([135,  58, 119,  16,  37, 128,  63,   2,  99,  23, 161, 132,  29, 164,
         190, 130, 131, 150,  80, 168, 170,  91, 180, 121, 117, 199, 189, 165,
         104, 193,  70, 100,  57, 156,  75,  31,  23,   8,  49,  74, 133,  56,
          66, 148,  42, 187, 150, 144,  91, 171, 123, 111, 167, 146,  99,  23,
         186,  40,  21,  30, 164, 143,  21,  90, 130, 191,  99,   0,  90, 172,
         138,  42,  32, 191,  94,  74,  24, 157, 195, 115, 156,  11, 108, 182,
         132, 170, 166, 119,   3,  88, 131,  15, 182, 136,   4, 165, 127, 121,
         108, 168]),
 tensor([  4,   8,  12,  15,  18,  18,  24,  26,  30,  32,  34,  38,  39,  40,
          41,  46,  46,  48,  49,  50,  52,  53,  54,  54,  56,  57,  58,  58,
          58,  59,  59,  61,  62,  63,  64,  65,  67,  69,  72,  72,  73,  74,
          74,  74,  79,  79,  84,  85,  89,  89,  90,  91,  96,  97,  98,  99,
         104, 104, 105, 106, 109, 110, 111, 115, 122, 123, 123, 124, 125, 130,
         132, 135, 138, 143, 14

In [9]:
csr

(tensor([   0,    7,   18,   27,   35,   41,   54,   68,   74,   86,  100,  108,
          120,  133,  140,  149,  156,  164,  174,  181,  188,  199,  208,  223,
          228,  239,  249,  257,  269,  275,  285,  303,  309,  322,  332,  343,
          351,  364,  371,  375,  385,  395,  409,  423,  428,  440,  459,  469,
          478,  486,  497,  515,  527,  542,  547,  554,  567,  574,  581,  597,
          604,  615,  628,  636,  647,  654,  661,  665,  676,  684,  693,  706,
          721,  730,  739,  750,  763,  775,  779,  785,  800,  810,  820,  833,
          847,  859,  866,  872,  881,  892,  898,  909,  921,  931,  941,  958,
          963,  973,  984,  991, 1000]),
 tensor([ 22,  48,  50,  82, 117, 131, 195,  13,  30,  37,  47,  64, 105, 112,
         134, 148, 161, 165,  19,  30,  59,  62,  68,  89, 106, 129, 138,   8,
          13,  27,  49,  96, 112, 120, 121,  20,  42,  42,  71, 110, 131,   4,
          51,  54,  72,  72,  89,  91, 106, 115, 144, 176, 189, 189,   1, 

In [11]:
csc

(tensor([ 24,  86, 102, 125, 156, 174,  18,  21,  45,  76,  84,  90, 111, 111,
         139, 164,  37,  39,  55,  70,  80,  92, 114, 115, 139, 154, 166, 181,
          29,  35,  48, 147, 185,   4,  24,  40,  65,  71,  89, 102, 116, 119,
         128, 132, 154, 156, 182, 187, 192,   1,   3,  28,  46,  49,  72,  83,
          88, 118, 130, 132, 144, 164, 182, 189, 194,  16,  31,  36,  37,  39,
          82,  89, 160, 166, 171, 178, 183,   4,  28,  40,  43,  43,  64,  85,
          86,  95, 137, 149, 154, 168, 182, 190, 192, 199,  10,  70,  75, 123,
         129, 145,   4,  25,  53,  71,  88,  95, 111, 169, 174,   4,  44,  53,
          64,  77,  81,  90, 113, 166,   6,  44,  45,  75,  77,  83,  91, 116,
         133, 160, 173, 190, 198,  20,  49,  68, 107, 131, 132, 138, 143, 179,
         189,  29,  34,  36,  40,  47,  80,  83, 123, 153, 160,  13,  32,  37,
          39, 155, 173,  34,  99, 127, 151, 178, 179,   5,  12,  19,  43,  49,
         120, 131, 153, 155, 171, 187, 191, 199,   2