# Vertical Federated Image Segmentation

Here, we create a Vertical Federated Autoencoder on MNIST as a proof of concept.

## Prepare MNIST Data

Download the guest and host MNIST datasets from the links below and unzip them into the project's examples/data folder:

- guest data: https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/mnist_guest.zip

- host data: https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/mnist_host.zip
  
mnist_guest is a simplified version of the MNIST dataset with ten categories; its images are sorted into ten folders, named 0 to 9, according to their labels. mnist_host contains the same images as mnist_guest, but without labels.

In [3]:
! ls ../../../../examples/data/mnist_guest

0  1  2  3  4  5  6  7	8  9


In [4]:
! ls ../../../../examples/data/mnist_host

not_labeled
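
As a rough illustration of the folder layout described above, the sketch below walks a directory tree and pairs each file with the name of its parent folder. The helper name `index_image_folder` and the temporary stand-in tree are hypothetical; FATE's own image dataset may read the folders differently.

```python
import tempfile
from pathlib import Path


def index_image_folder(root):
    """Return sorted (relative_path, folder_name) pairs for files one level below root."""
    root = Path(root)
    pairs = []
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for image_path in sorted(class_dir.iterdir()):
            if image_path.is_file():
                pairs.append((image_path.relative_to(root).as_posix(), class_dir.name))
    return pairs


# Build a tiny stand-in tree (not the real MNIST data) to show the idea.
with tempfile.TemporaryDirectory() as tmp:
    for label in ("0", "1"):
        class_dir = Path(tmp) / label
        class_dir.mkdir()
        (class_dir / f"img_{label}.jpg").touch()
    demo_pairs = index_image_folder(tmp)

print(demo_pairs)
```

For mnist_guest the folder name would play the role of the label; for mnist_host every image would map to the single not_labeled folder.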


## Dataset

In version FATE-1.10, FATE introduces a new base class for datasets called Dataset, which is based on PyTorch's Dataset class. This class allows users to create custom datasets according to their specific needs. The usage is similar to that of PyTorch's Dataset class, with the added requirement of implementing two additional interfaces when using FATE-NN for data reading and training: load() and get_sample_ids().

To create a custom dataset in Hetero-NN, users need to:

- Develop a new dataset class that inherits from the Dataset class
- Implement the \_\_len\_\_() and \_\_getitem\_\_() methods, consistent with PyTorch's Dataset usage. \_\_len\_\_() should return the length of the dataset, and \_\_getitem\_\_() should return the data at the specified index. **Note, however, that \_\_getitem\_\_() behaves differently across parties: in the guest party (the party with labels), \_\_getitem\_\_() returns features and labels, while in the host parties (the parties without labels), it returns features only.**
- Implement the load(), get_sample_ids(), get_classes() methods
  
For those unfamiliar with PyTorch's Dataset class, more information can be found in the PyTorch documentation: [Pytorch Dataset Documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)
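
The interface listed above can be sketched in plain Python. This is a minimal illustration only: `ToyImageDataset`, its `return_label` flag, and the in-memory `records` are hypothetical, and a real implementation would inherit from FATE's Dataset base class and read its data from storage inside load().

```python
class ToyImageDataset:
    """Stand-in for a custom dataset; a real one would inherit from FATE's Dataset."""

    def __init__(self, return_label=True):
        self.return_label = return_label  # True on the guest, False on hosts
        self.sample_ids = []
        self.features = []
        self.labels = []

    def load(self, records):
        # Assumption: records are (sample_id, feature, label) tuples held in memory,
        # used here only to keep the sketch self-contained.
        for sample_id, feature, label in records:
            self.sample_ids.append(sample_id)
            self.features.append(feature)
            self.labels.append(label)

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, index):
        if self.return_label:
            return self.features[index], self.labels[index]  # guest: features and labels
        return self.features[index]  # host: features only

    def get_sample_ids(self):
        return list(self.sample_ids)

    def get_classes(self):
        return sorted(set(self.labels))


records = [("img_1", [0.1, 0.2], 0), ("img_3", [0.3, 0.4], 1)]
guest_ds = ToyImageDataset(return_label=True)
guest_ds.load(records)
host_ds = ToyImageDataset(return_label=False)
host_ds.load(records)

print(guest_ds[0])  # features and label
print(host_ds[0])   # features only
```

Both parties expose the same sample ids, which is what lets the two sides line up their samples during training.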

## Customize Bottom/Top Model

Save the bottom model code to federatedml/nn/model_zoo, either by placing the file there directly or by using the Jupyter notebook shortcut save_to_fate. The cell below saves the bottom model, an encoder that performs feature extraction, as mnist_encoder.py.

In [6]:
from pipeline.component.nn import save_to_fate

In [7]:
%%save_to_fate model mnist_encoder.py
import torch as t
from torch import nn
from torch.nn import Module

class Encoder(nn.Module):
    
    def __init__(self, encoded_space_dim=4, fc2_input_dim=128):
        super().__init__()
        
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )
        
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        ### Linear section
        self.encoder_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32, 128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )
        
    def forward(self, x):
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        print("encode flatten",x.size())
        x = self.encoder_lin(x)
        print("encode linear",x.size())
        return x
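
The `3 * 3 * 32` input size of the first linear layer can be checked with the standard Conv2d output-size formula. The sketch below assumes 28x28 input images (the usual MNIST size); the helper `conv2d_out` is introduced only for this illustration.

```python
def conv2d_out(size, kernel, stride, padding):
    """Output height/width of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for the three Conv2d layers in Encoder
encoder_layers = [(3, 2, 1), (3, 2, 1), (3, 2, 0)]

size = 28  # assumption: 28x28 input images
encoder_sizes = []
for kernel, stride, padding in encoder_layers:
    size = conv2d_out(size, kernel, stride, padding)
    encoder_sizes.append(size)

flattened = size * size * 32  # 32 output channels from the last Conv2d
print(encoder_sizes, flattened)  # [14, 7, 3] 288
```

If the dataset resizes images to something other than 28x28, the flattened size changes and the first `nn.Linear` would need to be adjusted accordingly.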

In [35]:
%%save_to_fate model dummy.py
import torch as t
from torch import nn
from torch.nn import Module

class DummyNet(nn.Module):

    def __init__(self):
        super(DummyNet, self).__init__()
        self.fc = t.nn.Linear(1,1)

    def forward(self, x):
        #x = x.reshape([-1, 1])
        x = self.fc(x)
        return x

This is the top model, a decoder. The cell below saves it as mnist_decoder.py.

In [47]:
%%save_to_fate model mnist_decoder.py
import torch as t
from torch import nn
from torch.nn import Module

class Decoder(nn.Module):
    
    def __init__(self, encoded_space_dim=4, fc2_input_dim=128):
        super().__init__()
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(True)
        )

        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 3, 3, stride=2, padding=1, output_padding=1)
        )
        
    def forward(self, x):
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = t.sigmoid(x)
        return x
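
The spatial sizes produced by the three ConvTranspose2d layers can be traced with the standard transposed-convolution output-size formula. This is a small sketch; the helper `conv_transpose2d_out` is introduced only for this illustration and assumes dilation 1.

```python
def conv_transpose2d_out(size, kernel, stride, padding=0, output_padding=0):
    """Output height/width of a ConvTranspose2d layer with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# (kernel, stride, padding, output_padding) for the three layers in Decoder
decoder_layers = [(3, 2, 0, 0), (3, 2, 1, 1), (3, 2, 1, 1)]

size = 3  # spatial size after Unflatten: (32, 3, 3)
decoder_sizes = []
for kernel, stride, padding, output_padding in decoder_layers:
    size = conv_transpose2d_out(size, kernel, stride, padding, output_padding)
    decoder_sizes.append(size)

print(decoder_sizes)  # [7, 14, 28]
```

Under these settings the decoder maps a (32, 3, 3) tensor to a 3-channel 28x28 output.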

We can now use our models and loss in the Hetero-NN MNIST task. As in Homo-NN, custom models are specified through the nn.CustModel interface and custom losses through nn.CustLoss; this example uses the built-in MSELoss for the loss.

## Pipeline Initialization

Here we define the pipeline that runs the hetero task.

In [1]:
import os
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HeteroNN
from pipeline.component.hetero_nn import DatasetParam
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model
from pipeline.component.nn import save_to_fate

fate_torch_hook(t)

# bind path to fate name&namespace
fate_project_path = os.path.abspath('/data/projects/fate')
guest = 10000
host = 9999

pipeline_img = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)

guest_data = {"name": "mnist_guest", "namespace": "experiment"}
host_data = {"name": "mnist_host", "namespace": "experiment"}

guest_data_path = fate_project_path + '/examples/data/mnist_guest/'
host_data_path = fate_project_path + '/examples/data/mnist_host/'
pipeline_img.bind_table(name='mnist_guest', namespace='experiment', path=guest_data_path)
pipeline_img.bind_table(name='mnist_host', namespace='experiment', path=host_data_path)


{'namespace': 'experiment', 'table_name': 'mnist_host'}

In [2]:
guest_data = {"name": "mnist_guest", "namespace": "experiment"}
host_data = {"name": "mnist_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=guest_data)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=host_data)

In [3]:
hetero_nn_0 = HeteroNN(name="hetero_nn_0", epochs=10,
                       interactive_layer_lr=0.01, batch_size=512, task_type='regression', seed=100
                       )
guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)
host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)

# define models through the CustModel interface
# guest bottom model: the dummy network
guest_bottom = t.nn.CustModel(module_name='dummy.py', class_name='DummyNet')

# host bottom model: the encoder
host_bottom = t.nn.CustModel(module_name='mnist_encoder.py', class_name='Encoder')

# guest top model: the decoder
guest_top = t.nn.CustModel(module_name='mnist_decoder.py', class_name='Decoder')

# interactive layer define
interactive_layer = t.nn.InteractiveLayer(out_dim=4, guest_dim=None, host_dim=4)

# add models
guest_nn_0.add_top_model(guest_top)
guest_nn_0.add_bottom_model(guest_bottom)
host_nn_0.add_bottom_model(host_bottom)

# opt, loss
optimizer = t.optim.Adam(lr=0.01) 
#loss = t.nn.CustLoss(loss_module_name='ce', class_name='CrossEntropyLoss')
loss = t.nn.MSELoss()

# use DatasetParam to specify dataset and pass parameters
guest_nn_0.add_dataset(DatasetParam(dataset_name='imagelabel', return_label=True))
host_nn_0.add_dataset(DatasetParam(dataset_name='image', return_label=False))

hetero_nn_0.set_interactive_layer(interactive_layer)
hetero_nn_0.compile(optimizer=optimizer, loss=loss)


In [4]:
pipeline_img.add_component(reader_0)
pipeline_img.add_component(hetero_nn_0, data=Data(train_data=reader_0.output.data))
#pipeline_img.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=hetero_nn_0.output.data))
pipeline_img.add_component(Evaluation(name='eval_0'), data=Data(data=hetero_nn_0.output.data))
pipeline_img.compile()

<pipeline.backend.pipeline.PipeLine at 0x7f99bc68db80>

In [12]:
pipeline_img.fit()

2023-09-12 20:18:18.339 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202309122018180959410
2023-09-12 20:18:18.347 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:00
2023-09-12 20:18:19.355 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:01
2023-09-12 20:18:20.369 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:125 -
2023-09-12 20:18:20.370 | INFO     | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:02
2023-09-12 20:18:21.381 | INFO

In [60]:
pipeline_img.get_component('hetero_nn_0').get_output_data()  # get result

| | id | label | predict_result | predict_score | predict_detail | type |
|---|---|---|---|---|---|---|
| 0 | img_1 | 0.9450980424880981 | 0.4692811071872711 | 0.4692811071872711 | {'label': 0.4692811071872711} | train |
| 1 | img_3 | 0.0235294122248888 | 0.13761655986309052 | 0.13761655986309052 | {'label': 0.13761655986309052} | train |
| 2 | img_4 | 0.019607843831181526 | 0.002599237486720085 | 0.002599237486720085 | {'label': 0.002599237486720085} | train |
| 3 | img_5 | 0.0 | 0.026792803779244423 | 0.026792803779244423 | {'label': 0.026792803779244423} | train |
| 4 | img_6 | 0.9607843160629272 | 0.2996892035007477 | 0.2996892035007477 | {'label': 0.2996892035007477} | train |
| ... | ... | ... | ... | ... | ... | ... |
| 1304 | img_32537 | 0.003921568859368563 | 0.04695115610957146 | 0.04695115610957146 | {'label': 0.04695115610957146} | train |
| 1305 | img_32558 | 0.0 | 0.09780637919902802 | 0.09780637919902802 | {'label': 0.09780637919902802} | train |
| 1306 | img_32563 | 1.0 | 0.0337691493332386 | 0.0337691493332386 | {'label': 0.0337691493332386} | train |
| 1307 | img_32565 | 0.0 | 0.4692811071872711 | 0.4692811071872711 | {'label': 0.4692811071872711} | train |
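
One way to read the output is to compare the label and predict_result columns directly. The sketch below computes the mean squared error over only the nine rows visible above, so it is an illustration of the calculation and not a measure of the full result; the variable names are chosen for this example.

```python
# (label, predict_result) pairs copied from the nine visible rows
visible_rows = [
    (0.9450980424880981, 0.4692811071872711),
    (0.0235294122248888, 0.13761655986309052),
    (0.019607843831181526, 0.002599237486720085),
    (0.0, 0.026792803779244423),
    (0.9607843160629272, 0.2996892035007477),
    (0.003921568859368563, 0.04695115610957146),
    (0.0, 0.09780637919902802),
    (1.0, 0.0337691493332386),
    (0.0, 0.4692811071872711),
]

# Mean squared error: average of squared differences between label and prediction
squared_errors = [(label - pred) ** 2 for label, pred in visible_rows]
mse = sum(squared_errors) / len(squared_errors)
print(f"MSE over {len(visible_rows)} visible rows: {mse:.4f}")
```

This mirrors the MSELoss used for training, applied to a handful of rows by hand.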


