# Source:
- Pyg highlevel documentation:
    - https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html#feature-store
<br><br>
- Pyg Test of feature store: 
    - https://github.com/pyg-team/pytorch_geometric/blob/901a255346009c7294fd3cc1e825aa441f1dbd4f/torch_geometric/testing/feature_store.py
<br><br>
- Youtube pyg batch video:
    - https://www.youtube.com/watch?v=mz9xYNg9Ofs
    
    
---- 
Actually 2.2.0 tagged version
https://github.com/pyg-team/pytorch_geometric/tree/2.2.0

In [1]:
import sys
import platform
import torch
import torch_geometric

print("Platform", platform.system(), platform.release())
print("Python version",sys.version)

print("torch",torch.__version__)
print("torch_geomeric",torch_geometric.__version__)



  from .autonotebook import tqdm as notebook_tqdm


Platform Darwin 22.3.0
Python version 3.8.12 (default, Jul 12 2022, 16:17:42) 
[Clang 13.1.6 (clang-1316.0.21.2.5)]
torch 1.13.1
torch_geomeric 2.2.0


In [2]:
from typing import Dict, List, Optional, Tuple

from torch import Tensor

## FEATURE STORE
from torch_geometric.data.feature_store import FeatureStore, TensorAttr
from torch_geometric.typing import FeatureTensorType

## GRAPH STORE
from torch_geometric.data.graph_store import EdgeAttr, GraphStore
from torch_geometric.typing import EdgeTensorType

from torch_sparse import SparseTensor

## Implement custom feature and graph store as dictionaries

In [3]:

class MyFeatureStore(FeatureStore):
    def __init__(self):
        super().__init__()
        self.store: Dict[Tuple[str, str], Tensor] = {}

    @staticmethod
    def key(attr: TensorAttr) -> str:
        return (attr.group_name, attr.attr_name)

    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        index = attr.index

        # None indices define the obvious index:
        if index is None:
            index = torch.arange(0, tensor.shape[0])

        # Store the index:
        self.store[MyFeatureStore.key(attr)] = (index, tensor)

        return True

    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
        index, tensor = self.store.get(MyFeatureStore.key(attr), (None, None))
        if tensor is None:
            return None

        # None indices return the whole tensor:
        if attr.index is None:
            return tensor

        # Empty slices return the whole tensor:
        if (isinstance(attr.index, slice)
                and attr.index == slice(None, None, None)):
            return tensor

        idx = (torch.cat([(index == v).nonzero() for v in attr.index]).view(-1)
               if attr.index.numel() > 0 else [])
        return tensor[idx]

    def _remove_tensor(self, attr: TensorAttr) -> bool:
        del self.store[MyFeatureStore.key(attr)]
        return True

    def _get_tensor_size(self, attr: TensorAttr) -> Tuple:
        return self._get_tensor(attr).size()

    def get_all_tensor_attrs(self) -> List[str]:
        return [TensorAttr(*key) for key in self.store.keys()]

    def __len__(self):
        # TODO
        return(1)


In [4]:

class MyGraphStore(GraphStore):
    def __init__(self):
        super().__init__()
        self.store: Dict[EdgeAttr, Tuple[Tensor, Tensor]] = {}

    @staticmethod
    def key(attr: EdgeAttr) -> str:
        return (attr.edge_type, attr.layout.value, attr.is_sorted, attr.size)

    def _put_edge_index(self, edge_index: EdgeTensorType,
                        edge_attr: EdgeAttr) -> bool:
        self.store[MyGraphStore.key(edge_attr)] = edge_index

    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
        return self.store.get(MyGraphStore.key(edge_attr), None)

    def get_all_edge_attrs(self):
        return [EdgeAttr(*key) for key in self.store]

### Actual tests
- Feature_store:
    - https://github.com/pyg-team/pytorch_geometric/blob/2.2.0/test/data/test_feature_store.py
- Graph_store:
    - https://github.com/pyg-team/pytorch_geometric/blob/2.2.0/test/data/test_graph_store.py

In [5]:
# Feature store - heterogenous
feature_store = MyFeatureStore()
tensor_a = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
tensor_b = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]])


group_name_a = 'A'
attr_name_a = 'feat_a'
group_name_b = 'B'
attr_name_b = 'feat_b'

index_a = torch.tensor([0, 1, 2])
index_b = torch.tensor([0, 1, 2])
attr_a = TensorAttr(group_name_a, attr_name_a, index_a)
attr_b = TensorAttr(group_name_b, attr_name_b, index_b)


feature_store.put_tensor(tensor_a, attr_a)
feature_store.put_tensor(tensor_b, attr_b)

# Graph store - heterogenous
graph_store = MyGraphStore()
edge_index_ab = torch.LongTensor([(0, 1), (1, 2), (2,0),(0,2)])
#adj = SparseTensor(row=edge_index_ab[0], col=edge_index_ab[1])
coo = (edge_index_ab[0], edge_index_ab[1])


edge_attr_ab = torch_geometric.data.graph_store.EdgeAttr(
    edge_type = ("A","link_name_ab","B"),
    layout = "csr",
    is_sorted = False,
    size = (2,2))


graph_store.put_edge_index(edge_index = coo,
                           edge_type=('A', '1', 'B'),
                            layout='coo', 
                            size=(2, 2),
                           is_sorted=False
                          )


node_sampler = torch_geometric.sampler.NeighborSampler((feature_store,graph_store), num_neighbors=[1],input_type="csr")


loader = torch_geometric.loader.NodeLoader(
    data=(feature_store, graph_store),
    node_sampler=node_sampler,
    batch_size=1,
    input_nodes='A',
)

for batch in loader:
    pass

IndexError: phmap at(): lookup non-existent key