# Homo-Graph Customized Model

In FATE 1.10, we integrated Torch-geometric 2.2 into the FATE framework, so you can build Graph Neural Networks (GNNs) in a homo (horizontal) federated way. Homo-graph is an extension of the customized model, but it differs in its input data and its trainer.

## Install Torch-geometric

For installation, please refer to the [torch-geometric website](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html), or run the following command in your terminal for a quick install:

```bash
pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.13.0+cpu.html
```

## Cora Dataset

Cora is a graph dataset for multi-class node classification. It has 2708 nodes, about 10k edges, and 7 classes. Each node has 1433 features.

In federated homo-graph modeling, each party holds its own graph dataset with the same features, i.e., horizontal federation. The nodes in the two graphs may not overlap, and the parties do not exchange any information about their graph datasets during modeling.
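The parties still end up training one shared model. A common way to do this in homo (horizontal) federated learning is FedAvg-style aggregation: each party trains on its local graph and only model parameters are combined. The sketch below assumes such a sample-weighted average; FATE's actual aggregator (which may add secure-aggregation masking) is not reproduced here, and the parameter dictionary layout and sample counts are hypothetical.

```python
from typing import Dict, List

# Hypothetical representation: parameter name -> flat list of floats.
Params = Dict[str, List[float]]


def fed_avg(party_params: List[Params], num_samples: List[int]) -> Params:
    """Sample-weighted average of each party's parameters (FedAvg-style sketch)."""
    total = sum(num_samples)
    averaged: Params = {}
    for name in party_params[0]:
        length = len(party_params[0][name])
        averaged[name] = [
            sum(params[name][i] * n for params, n in zip(party_params, num_samples)) / total
            for i in range(length)
        ]
    return averaged


# Only parameters and sample counts leave each party; node features and edges stay local.
guest = {"w": [1.0, 2.0]}  # hypothetical locally trained weights
host = {"w": [3.0, 4.0]}
print(fed_avg([guest, host], [100, 300]))  # {'w': [2.5, 3.5]}
```

Weighting by sample count lets a party with more labeled nodes pull the shared model further toward its local solution; with equal counts this reduces to a plain mean.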

For simplicity, the host and the guest in the demo both use the same Cora dataset. The train/validation/test split is made by node index as follows:

- train: `[0:140]`
- validation: `[200:500]`
- test: `[500:1500]`
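One way to picture this split is as boolean node masks over all 2708 nodes. The sketch below assumes the ranges are end-exclusive like Python slices, which we infer from the notation; how FATE's loader actually builds its masks is not shown here, and `range_mask` is a hypothetical helper.

```python
from typing import List

NUM_NODES = 2708  # number of nodes in Cora


def range_mask(num_nodes: int, start: int, end: int) -> List[bool]:
    """Boolean node mask that is True for indices in [start, end), like a Python slice."""
    return [start <= i < end for i in range(num_nodes)]


train_mask = range_mask(NUM_NODES, 0, 140)
val_mask = range_mask(NUM_NODES, 200, 500)
test_mask = range_mask(NUM_NODES, 500, 1500)

print(sum(train_mask), sum(val_mask), sum(test_mask))  # 140 300 1000

# The three splits never share a node.
assert not any(a and b for a, b in zip(train_mask, val_mask))
assert not any(a and b for a, b in zip(val_mask, test_mask))
```

With PyTorch the same mask is usually a tensor, e.g. `mask = torch.zeros(2708, dtype=torch.bool)` followed by `mask[0:140] = True`. In a typical transductive setup, nodes outside all three masks still take part in message passing; they just contribute no labels to the loss.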

The preprocessed data can be found in `examples/data/cora4fate`.


## Upload data

```bash
cd {FATE project}/examples/pipeline/upload
python pipeline-upload-graph-cora.py {FATE project}/examples/data
```

## GraphSage Model

Name the model file `graphsage.py`. You can put it directly under `federatedml/nn/model_zoo`, or use the `%%save_to_fate` cell magic in a Jupyter notebook to save it there: import `save_to_fate` first, then start the model cell with `%%save_to_fate model graphsage.py`.

```python
from pipeline.component.nn import save_to_fate
```

```python
%%save_to_fate model graphsage.py

from torch import nn
import torch_geometric.nn as pyg


class Sage(nn.Module):
    def __init__(self, in_channels, hidden_channels, class_num):
        super().__init__()
        self.model = nn.ModuleList([
            # project=True applies a linear transformation and an activation
            # to neighbour features before they are aggregated
            pyg.SAGEConv(in_channels=in_channels, out_channels=hidden_channels, project=True),
            pyg.SAGEConv(in_channels=hidden_channels, out_channels=class_num),
            # per-node log-probabilities over the classes
            nn.LogSoftmax(dim=1),
        ])

    def forward(self, x, edge_index):
        for layer in self.model:
            if isinstance(layer, pyg.SAGEConv):
                # graph convolutions also need the edge list
                x = layer(x, edge_index)
            else:
                x = layer(x)
        return x
```
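Per the PyG documentation, `SAGEConv` with mean aggregation updates each node as `x'_i = W_1 x_i + W_2 * mean(x_j for j in N(i))`. The pure-Python sketch below reproduces only that update, leaving out the bias, optional normalization, and the `project=True` pre-projection, to show how `edge_index` is read: row 0 holds source nodes, row 1 holds target nodes, and each target averages the features of its sources. The helper names and the identity weight matrices are illustrative assumptions, not trained values or PyG internals.

```python
from typing import List

Matrix = List[List[float]]


def mean_aggregate(x: Matrix, edge_index: List[List[int]]) -> Matrix:
    """Average incoming neighbour features for every node.

    edge_index uses the PyG layout: edge_index[0] are source nodes,
    edge_index[1] are target nodes, and messages flow source -> target.
    Nodes without incoming edges receive a zero vector.
    """
    num_nodes, dim = len(x), len(x[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes
    for src, dst in zip(edge_index[0], edge_index[1]):
        for k in range(dim):
            sums[dst][k] += x[src][k]
        counts[dst] += 1
    return [
        [s / counts[i] if counts[i] else 0.0 for s in sums[i]]
        for i in range(num_nodes)
    ]


def matvec(w: Matrix, v: List[float]) -> List[float]:
    """Multiply matrix w by vector v."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def sage_mean_layer(x: Matrix, edge_index: List[List[int]],
                    w_root: Matrix, w_neigh: Matrix) -> Matrix:
    """x'_i = w_root @ x_i + w_neigh @ mean(x_j for edges j -> i), without bias."""
    agg = mean_aggregate(x, edge_index)
    return [
        [a + b for a, b in zip(matvec(w_root, x[i]), matvec(w_neigh, agg[i]))]
        for i in range(len(x))
    ]


# Three nodes with 2 features; edges 0->2, 1->2 and 2->0.
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
edge_index = [[0, 1, 2], [2, 2, 0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
print(sage_mean_layer(x, edge_index, identity, identity))
# [[3.0, 2.0], [0.0, 1.0], [2.5, 2.5]]
```

Node 1 has no incoming edges, so its output depends only on its own features; stacking two such layers, as `Sage` does, lets each node see its 2-hop neighbourhood.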

```python
# 1433 input features per Cora node, 7 output classes
homosage = Sage(in_channels=1433, hidden_channels=64, class_num=7)
homosage
```
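Because `Sage` ends in `nn.LogSoftmax(dim=1)`, each row of its output is a vector of class log-probabilities, and the PyTorch loss that pairs with log-probabilities is `nn.NLLLoss`. Assuming the trainer is configured with such a negative log-likelihood loss, the pure-Python sketch below shows what that pairing computes; the helper functions are illustrative, not FATE or PyTorch APIs.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Log-probabilities for one node, computed stably by subtracting the max logit."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def nll_loss(log_probs: List[List[float]], targets: List[int]) -> float:
    """Mean negative log-likelihood of the true class, as nn.NLLLoss does by default."""
    return -sum(lp[t] for lp, t in zip(log_probs, targets)) / len(targets)


# An untrained model that outputs equal logits for Cora's 7 classes
# assigns probability 1/7 to every class, so the loss is ln(7).
log_probs = [log_softmax([0.0] * 7), log_softmax([0.0] * 7)]
print(round(nll_loss(log_probs, [3, 5]), 4))  # 1.9459
```

A loss near ln(7) ≈ 1.95 at the start of training is therefore expected for this model; it should fall as the predicted probability of the true class rises.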

## Submit a Homo-NN task with the Custom Model

```bash
cd {FATE project}/examples/pipeline/homo_graph
python pipeline_homo_graph_sage.py
```