# Vertical Federated Image Segmentation

A vertical federated image segmentation implementation that detects roads in the CamVid dataset.

CamVid is available here: https://www.kaggle.com/datasets/carlolepelaars/camvid

## Dataset

FATE 1.10 introduces a new base class for datasets, Dataset, built on PyTorch's Dataset class. It lets users create custom datasets for their specific needs. Usage mirrors PyTorch's Dataset class, with one addition: to read data and train with FATE-NN, a dataset must also implement two interfaces, load() and get_sample_ids().

To create a custom dataset in Hetero-NN, you need to:

- Develop a new dataset class that inherits from the Dataset class.
- Implement the \_\_len\_\_() and \_\_getitem\_\_() methods, consistent with PyTorch's Dataset usage: \_\_len\_\_() returns the length of the dataset, and \_\_getitem\_\_() returns the data at the specified index. **Note that \_\_getitem\_\_() may behave differently on different parties: on the guest (the party holding labels) it returns features and labels, while on the hosts (parties without labels) it returns features only.**
- Implement the load(), get_sample_ids() and get_classes() methods.

If you are unfamiliar with PyTorch's Dataset class, see the PyTorch documentation: [PyTorch Dataset Documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)
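The interface described above can be sketched without FATE installed. This is a minimal, hypothetical illustration: DatasetStub stands in for FATE's Dataset base class (in a real job you would subclass FATE's class instead), ToyVerticalDataset is an invented name, and load() takes in-memory records rather than the file path FATE passes, purely to keep the sketch self-contained. The same class serves both parties, with a return_label flag (mirroring the DatasetParam argument used later) switching between guest and host behaviour.

```python
from typing import Any, Iterable, List, Optional, Tuple


class DatasetStub:
    """Hypothetical stand-in for FATE's Dataset base class, so this runs without FATE."""

    def load(self, source: Any) -> None:
        raise NotImplementedError

    def get_sample_ids(self) -> List[str]:
        raise NotImplementedError


class ToyVerticalDataset(DatasetStub):
    """One class for both parties; return_label switches guest vs host behaviour."""

    def __init__(self, return_label: bool = True) -> None:
        self.return_label = return_label
        self.ids: List[str] = []
        self.features: List[Any] = []
        self.labels: List[Optional[float]] = []

    def load(self, source: Iterable[Tuple[str, Any, Optional[float]]]) -> None:
        # Simplification: FATE's load() receives a path; here we accept
        # (sample_id, feature, label) records directly.
        for sample_id, feature, label in source:
            self.ids.append(sample_id)
            self.features.append(feature)
            self.labels.append(label)

    def get_sample_ids(self) -> List[str]:
        # In a vertical setting, guest and host should report matching IDs
        # for the same sample.
        return list(self.ids)

    def get_classes(self) -> List[float]:
        # Distinct label values seen during load()
        return sorted({lab for lab in self.labels if lab is not None})

    def __len__(self) -> int:
        return len(self.ids)

    def __getitem__(self, idx: int):
        if self.return_label:       # guest: features and label
            return self.features[idx], self.labels[idx]
        return self.features[idx]   # host: features only
```

Loading the same records into ToyVerticalDataset(return_label=True) and ToyVerticalDataset(return_label=False) gives a (features, label) pair on the first and bare features on the second, which is exactly the guest/host asymmetry described in the list.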

## Customize Bottom/Top Model

Each model must be saved as a .py file under federatedml/nn/model_zoo. You can copy the file there directly, or use the Jupyter notebook shortcut save_to_fate, which saves a cell's contents straight to federatedml/nn/model_zoo. Below we define the guest top model (fcn_top.py), a placeholder guest bottom model (dummy.py) and the host bottom model used for feature extraction (fcn_bottom.py).
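Outside Jupyter, the first option amounts to writing the model source into the model zoo directory yourself. A minimal stdlib sketch of that manual step follows; save_model_source is a hypothetical helper (not part of FATE), and you should point model_zoo_dir at the federatedml/nn/model_zoo directory of your own FATE installation.

```python
from pathlib import Path


def save_model_source(source: str, filename: str, model_zoo_dir: str) -> Path:
    """Hypothetical helper: write model source code into a model_zoo directory."""
    target_dir = Path(model_zoo_dir)
    target_dir.mkdir(parents=True, exist_ok=True)  # create the directory if missing
    target = target_dir / filename
    target.write_text(source, encoding="utf-8")
    return target
```

For example, save_model_source(open("dummy.py").read(), "dummy.py", zoo_dir) would place dummy.py where t.nn.CustModel(module_name='dummy.py', ...) is expected to find it.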

```python
from pipeline.component.nn import save_to_fate
```

```python
%%save_to_fate model fcn_top.py
import torch.nn as nn


class FCN_Top(nn.Module):
    """Guest top model: decodes the 500-dim interactive-layer output into a
    per-pixel road mask with FCN-style transposed convolutions.
    Returns raw logits (no sigmoid)."""

    def __init__(self, n_class=1):
        super().__init__()
        self.n_class = n_class

        # each stride-2 deconvolution doubles the spatial size
        self.relu    = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1     = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2     = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3     = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4     = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5     = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

        # expand each compressed skip feature (25/50/75/150/200 dims, produced by
        # the host's Intermediate module) back to its feature-map size
        self.linear5 = nn.Linear(25, 512*4*4)
        self.linear4 = nn.Linear(50, 512*8*8)
        self.linear3 = nn.Linear(75, 256*16*16)
        self.linear2 = nn.Linear(150, 128*32*32)
        self.linear1 = nn.Linear(200, 64*64*64)

        self.unflatten5 = nn.Unflatten(dim=1, unflattened_size=(512, 4, 4))
        self.unflatten4 = nn.Unflatten(dim=1, unflattened_size=(512, 8, 8))
        self.unflatten3 = nn.Unflatten(dim=1, unflattened_size=(256, 16, 16))
        self.unflatten2 = nn.Unflatten(dim=1, unflattened_size=(128, 32, 32))
        self.unflatten1 = nn.Unflatten(dim=1, unflattened_size=(64, 64, 64))

    def forward(self, x):
        # x: (N, 500), the concatenation [x5 | x4 | x3 | x2 | x1] sent by the host
        x5 = x[:, 0:25]
        x4 = x[:, 25:75]
        x3 = x[:, 75:150]
        x2 = x[:, 150:300]
        x1 = x[:, 300:500]

        x5 = self.unflatten5(self.linear5(x5))  # (N, 512, 4, 4)
        x4 = self.unflatten4(self.linear4(x4))  # (N, 512, 8, 8)
        x3 = self.unflatten3(self.linear3(x3))  # (N, 256, 16, 16)
        x2 = self.unflatten2(self.linear2(x2))  # (N, 128, 32, 32)
        x1 = self.unflatten1(self.linear1(x1))  # (N, 64, 64, 64)

        score = self.bn1(self.relu(self.deconv1(x5)))     # size=(N, 512, x.H/16, x.W/16)
        score = score + x4                                # element-wise add, size=(N, 512, x.H/16, x.W/16)
        score = self.bn2(self.relu(self.deconv2(score)))  # size=(N, 256, x.H/8, x.W/8)
        score = score + x3                                # element-wise add, size=(N, 256, x.H/8, x.W/8)
        score = self.bn3(self.relu(self.deconv3(score)))  # size=(N, 128, x.H/4, x.W/4)
        score = score + x2                                # element-wise add, size=(N, 128, x.H/4, x.W/4)
        score = self.bn4(self.relu(self.deconv4(score)))  # size=(N, 64, x.H/2, x.W/2)
        score = score + x1                                # element-wise add, size=(N, 64, x.H/2, x.W/2)
        score = self.bn5(self.relu(self.deconv5(score)))  # size=(N, 32, x.H, x.W)
        score = self.classifier(score)                    # size=(N, n_class, x.H, x.W)

        # drop the channel dimension in the single-class (road / not road) case
        score = score.squeeze(1)

        return score  # size=(N, x.H, x.W) when n_class == 1
```
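The slicing offsets in forward() and the skip-connection shapes have to line up exactly. A small pure-Python check, using the transposed-convolution output-size formula from the PyTorch ConvTranspose2d documentation, confirms that the five stride-2 deconvolutions take the 4x4 x5 map to 128x128 (so the model implicitly assumes 128x128 input images, which is our inference from the layer sizes) and that the slices tile exactly 500 features. The helper names are illustrative only.

```python
def conv_transpose2d_out(size, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1):
    # Output-size formula from the PyTorch ConvTranspose2d docs (one spatial dim)
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


# Widths of the slices FCN_Top.forward() takes from the 500-dim input, in order
SLICES = {"x5": 25, "x4": 50, "x3": 75, "x2": 150, "x1": 200}


def slice_bounds(widths):
    """Turn consecutive widths into (start, end) column ranges."""
    bounds, start = {}, 0
    for name, width in widths.items():
        bounds[name] = (start, start + width)
        start += width
    return bounds


def decoder_sizes(start=4, n_deconv=5):
    """Spatial size after each of the decoder's transposed convolutions."""
    sizes = [start]
    for _ in range(n_deconv):
        sizes.append(conv_transpose2d_out(sizes[-1]))
    return sizes
```

decoder_sizes() yields 4, 8, 16, 32, 64, 128, matching the unflatten shapes of x4 to x1 at each skip addition, and slice_bounds(SLICES) reproduces the hard-coded ranges 0:25, 25:75, 75:150, 150:300 and 300:500.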

The guest holds only the labels, so its bottom model is a minimal placeholder:

```python
%%save_to_fate model dummy.py
import torch as t
from torch import nn


class DummyNet(nn.Module):
    """Placeholder guest bottom model: a single 1 -> 1 linear layer."""

    def __init__(self):
        super().__init__()
        self.fc = t.nn.Linear(1, 1)

    def forward(self, x):
        return self.fc(x)
```

And here is the host bottom model. It extracts features with a VGG16 backbone and compresses the output of each of its five pooling stages into a short vector (25, 50, 75, 150 and 200 features) before passing them on:

```python
%%save_to_fate model fcn_bottom.py
import torch as t
from torch import nn
from torchvision import models
from torchvision.models.vgg import VGG

# index ranges of self.features that end right after each of VGG's five max-pooling layers
ranges = {
    'vgg11': ((0, 3), (3, 6),  (6, 11),  (11, 16), (16, 21)),
    'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
    'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
    'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
}

# cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
cfg = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


class VGGNet(VGG):
    def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]

        if pretrained:
            # load ImageNet weights, e.g. models.vgg16(pretrained=True);
            # newer torchvision versions deprecate pretrained= in favour of weights=
            self.load_state_dict(getattr(models, model)(pretrained=True).state_dict())

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

        if remove_fc:  # delete the unused fully-connected classifier to save memory
            del self.classifier

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        output = {}

        # collect the output of each max-pooling stage (5 in VGG)
        for idx, (start, end) in enumerate(self.ranges):
            for layer in range(start, end):
                x = self.features[layer](x)
            output["x%d" % (idx + 1)] = x

        return output


class Intermediate(nn.Module):
    """Compress each VGG stage to a small vector: x1 -> 200, x2 -> 150,
    x3 -> 75, x4 -> 50, x5 -> 25, i.e. 500 features per sample in total."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten(start_dim=1)
        self.linear5 = nn.Linear(512*4*4, 25)
        self.linear4 = nn.Linear(512*8*8, 50)
        self.linear3 = nn.Linear(256*16*16, 75)
        self.linear2 = nn.Linear(128*32*32, 150)
        self.linear1 = nn.Linear(64*64*64, 200)

    def forward(self, x):
        x5 = x['x5']  # size=(N, 512, x.H/32, x.W/32)
        x4 = x['x4']  # size=(N, 512, x.H/16, x.W/16)
        x3 = x['x3']  # size=(N, 256, x.H/8,  x.W/8)
        x2 = x['x2']  # size=(N, 128, x.H/4,  x.W/4)
        x1 = x['x1']  # size=(N, 64,  x.H/2,  x.W/2)

        # concatenate along the feature dimension in the order the top model
        # slices it: [x5 | x4 | x3 | x2 | x1] -> (N, 500)
        return t.cat((
            self.linear5(self.flatten(x5)),
            self.linear4(self.flatten(x4)),
            self.linear3(self.flatten(x3)),
            self.linear2(self.flatten(x2)),
            self.linear1(self.flatten(x1)),
        ), dim=1)


class FCN_Bottom(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = VGGNet(requires_grad=True, remove_fc=True)
        self.reduc = Intermediate()

    def forward(self, x):
        # take element 1 of the incoming batch (the image tensor)
        x = x[1]
        x = self.vgg(x)
        return self.reduc(x)
```
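The hard-coded ranges table and the in-features of the Intermediate linear layers both follow from the VGG configuration. The pure-Python sketch below (no torch needed; function names are illustrative) re-derives them from the vgg16 entry of cfg, assuming batch_norm=False as VGGNet uses, and a 128x128 input, which we infer from the linear-layer sizes rather than from anything stated in the original.

```python
VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']


def stage_ranges(cfg):
    """(start, end) indices in make_layers(cfg) that end right after each max-pool."""
    ranges, idx, start = [], 0, 0
    for v in cfg:
        # MaxPool2d is one layer; Conv2d + ReLU (no batch norm) are two
        idx += 1 if v == 'M' else 2
        if v == 'M':
            ranges.append((start, idx))
            start = idx
    return tuple(ranges)


def stage_shapes(cfg, image_size=128):
    """(channels, height, width) after each max-pool for a square RGB input."""
    shapes, channels, size = [], 3, image_size
    for v in cfg:
        if v == 'M':
            size //= 2            # 2x2 max-pool with stride 2 halves the size
            shapes.append((channels, size, size))
        else:
            channels = v          # 3x3 conv with padding=1 keeps the size
    return shapes
```

stage_ranges(VGG16_CFG) reproduces ranges['vgg16'], and the flattened stage sizes (64*64*64 down to 512*4*4) are exactly the in-features of linear1 to linear5. With a different image size, both Intermediate and FCN_Top would need new linear-layer dimensions.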
        

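As a quick sanity check, BCELoss is zero when the predictions equal the targets:

```python
import torch as t

# BCELoss expects probabilities in [0, 1]; the pipeline below uses
# BCEWithLogitsLoss instead, which applies the sigmoid internally
bce = t.nn.BCELoss()

x = t.ones([1, 2, 2])
y = t.ones([1, 2, 2])
bce(x, y)  # tensor(0.)
```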

We can now use these models and the loss in the Hetero-NN segmentation task. As in Homo-NN, custom models are specified through the nn.CustModel interface (custom losses would go through nn.CustLoss; here we use the built-in BCEWithLogitsLoss).
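FCN_Top returns raw scores from its final 1x1 convolution, which is why the pipeline compiles with BCEWithLogitsLoss rather than BCELoss: the former folds the sigmoid into the loss. The scalar, pure-Python sketch below writes out both per-element formulas (PyTorch averages them over all elements by default) and shows they agree; the logit version uses the standard numerically stable rearrangement. Soft targets such as the fractional labels seen in the output later work with either form.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce(p, y):
    """Binary cross-entropy on a probability p in (0, 1) and a target y in [0, 1]."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def bce_with_logits(z, y):
    """Binary cross-entropy on a raw logit z, in the stable form
    max(z, 0) - z*y + log(1 + exp(-|z|)), which never overflows exp()."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
```

For moderate logits bce_with_logits(z, y) equals bce(sigmoid(z), y) to floating-point precision, while for very large |z| only the logit form stays finite, since sigmoid(z) rounds to exactly 0 or 1 and log() of 0 fails.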

## Pipeline Initialization

Here we define the pipeline that runs the hetero (vertical) federated task.

```python
import os
import torch as t
from pipeline import fate_torch_hook
from pipeline.component import HeteroNN
from pipeline.component.hetero_nn import DatasetParam
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation
from pipeline.interface import Data

fate_torch_hook(t)

# bind local data paths to FATE table names & namespaces
fate_project_path = os.path.abspath('/data/projects/fate')
guest = 10000
host = 9999

pipeline_img = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)

guest_data = {"name": "fcn_guest", "namespace": "experiment"}
host_data = {"name": "fcn_host", "namespace": "experiment"}

# the guest holds the label masks, the host holds the images
guest_data_path = fate_project_path + '/examples/data/CamVid/train_labels/'
host_data_path = fate_project_path + '/examples/data/CamVid/train/'
pipeline_img.bind_table(name='fcn_guest', namespace='experiment', path=guest_data_path)
pipeline_img.bind_table(name='fcn_host', namespace='experiment', path=host_data_path)
# the last call returns {'namespace': 'experiment', 'table_name': 'fcn_host'}
```
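In a vertical task each training sample needs both a guest row (the label mask) and a host row (the image). When both directories are visible on one machine, as in this example, a quick local check can catch mismatched files before the job is submitted. The sketch below is hypothetical: it assumes each file's stem is its sample ID, with an optional suffix stripped on the guest side, and you would adapt it to however your vertseg and vertimg datasets actually derive IDs.

```python
from pathlib import Path
from typing import Dict, Set


def sample_ids(directory: str, strip_suffix: str = "") -> Set[str]:
    """File stems in `directory`, with an optional trailing suffix removed.
    Treating the stem as the sample ID is an assumption of this sketch."""
    ids = set()
    for path in Path(directory).iterdir():
        if path.is_file():
            stem = path.stem
            if strip_suffix and stem.endswith(strip_suffix):
                stem = stem[: -len(strip_suffix)]
            ids.add(stem)
    return ids


def alignment_report(guest_dir: str, host_dir: str, guest_suffix: str = "") -> Dict[str, int]:
    """Count samples present on both parties vs. on only one of them."""
    guest_ids = sample_ids(guest_dir, guest_suffix)
    host_ids = sample_ids(host_dir)
    return {
        "shared": len(guest_ids & host_ids),
        "guest_only": len(guest_ids - host_ids),
        "host_only": len(host_ids - guest_ids),
    }
```

Calling alignment_report(guest_data_path, host_data_path) and seeing non-zero guest_only or host_only counts would indicate samples that cannot be used for training.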

```python
# the Reader loads each party's bound table (guest_data / host_data defined above)
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=guest_data)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=host_data)
```

```python
hetero_nn_0 = HeteroNN(name="hetero_nn_0", epochs=1,
                       interactive_layer_lr=0.01, batch_size=16, task_type='regression', seed=100)
guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)
host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)

# define models with CustModel, referencing the files saved to model_zoo above
# guest bottom: placeholder model (the guest holds only labels)
guest_bottom = t.nn.CustModel(module_name='dummy.py', class_name='DummyNet')

# host bottom: VGG16 feature extractor + compression to 500 features
host_bottom = t.nn.CustModel(module_name='fcn_bottom.py', class_name='FCN_Bottom')

# guest top: FCN decoder producing the segmentation mask
guest_top = t.nn.CustModel(module_name='fcn_top.py', class_name='FCN_Top')

# interactive layer: host_dim matches the host bottom output (25+50+75+150+200 = 500)
# and out_dim matches the 500 features FCN_Top.forward slices
interactive_layer = t.nn.InteractiveLayer(out_dim=500, guest_dim=None, host_dim=500)

# add models
guest_nn_0.add_top_model(guest_top)
guest_nn_0.add_bottom_model(guest_bottom)
host_nn_0.add_bottom_model(host_bottom)

# optimizer and loss: the top model outputs logits, so BCEWithLogitsLoss applies the sigmoid
optimizer = t.optim.SGD(guest_top.parameters(), lr=0.1, momentum=0.9)
loss = t.nn.BCEWithLogitsLoss()

# use DatasetParam to specify the dataset and pass parameters
# guest: label masks (vertseg), host: images (vertimg)
guest_nn_0.add_dataset(DatasetParam(dataset_name='vertseg', return_label=True))
host_nn_0.add_dataset(DatasetParam(dataset_name='vertimg', return_label=False))

hetero_nn_0.set_interactive_layer(interactive_layer)
hetero_nn_0.compile(optimizer=optimizer, loss=loss)
```
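The compiled optimizer is SGD with lr=0.1 and momentum=0.9. To make those two numbers concrete, here is a scalar sketch of the update rule PyTorch documents for torch.optim.SGD with default dampening, no Nesterov and no weight decay. It is an illustration, not FATE code, and it minimises a toy quadratic rather than the segmentation loss.

```python
def sgd_momentum(grad_fn, p0, lr=0.1, momentum=0.9, steps=3):
    """Heavy-ball SGD: v <- momentum * v + grad;  p <- p - lr * v.
    Returns the parameter value after each step."""
    p, v, history = p0, 0.0, []
    for _ in range(steps):
        g = grad_fn(p)
        v = momentum * v + g   # velocity accumulates past gradients
        p = p - lr * v
        history.append(p)
    return history


# Toy objective f(p) = p**2 with gradient 2p, starting from p = 1.0
trajectory = sgd_momentum(lambda p: 2.0 * p, 1.0)
```

Under a constant gradient the velocity approaches grad / (1 - momentum), so with momentum 0.9 the effective step size tends towards lr / (1 - 0.9) = 1.0, ten times the nominal learning rate.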

```python
pipeline_img.add_component(reader_0)
pipeline_img.add_component(hetero_nn_0, data=Data(train_data=reader_0.output.data))
pipeline_img.add_component(Evaluation(name='eval_0'), data=Data(data=hetero_nn_0.output.data))
pipeline_img.compile()
```

```python
pipeline_img.fit()
```

```text
2023-09-27 23:52:39.420 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202309272352391699800
2023-09-27 23:52:39.427 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:00
2023-09-27 23:52:40.446 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:125 -
2023-09-27 23:52:40.447 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:01
2023-09-27 23:52:41.471 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:02
2023-09-27 23:52:42.484 | I
```

```python
pipeline_img.get_component('hetero_nn_0').get_output_data()  # get result
```

| | id | label | predict_result | predict_score | predict_detail | type |
|---|---|---|---|---|---|---|
| 0 | img_1 | 0.9450980424880981 | 0.4692811071872711 | 0.4692811071872711 | {'label': 0.4692811071872711} | train |
| 1 | img_3 | 0.0235294122248888 | 0.13761655986309052 | 0.13761655986309052 | {'label': 0.13761655986309052} | train |
| 2 | img_4 | 0.019607843831181526 | 0.002599237486720085 | 0.002599237486720085 | {'label': 0.002599237486720085} | train |
| 3 | img_5 | 0.0 | 0.026792803779244423 | 0.026792803779244423 | {'label': 0.026792803779244423} | train |
| 4 | img_6 | 0.9607843160629272 | 0.2996892035007477 | 0.2996892035007477 | {'label': 0.2996892035007477} | train |
| ... | ... | ... | ... | ... | ... | ... |
| 1304 | img_32537 | 0.003921568859368563 | 0.04695115610957146 | 0.04695115610957146 | {'label': 0.04695115610957146} | train |
| 1305 | img_32558 | 0.0 | 0.09780637919902802 | 0.09780637919902802 | {'label': 0.09780637919902802} | train |
| 1306 | img_32563 | 1.0 | 0.0337691493332386 | 0.0337691493332386 | {'label': 0.0337691493332386} | train |
| 1307 | img_32565 | 0.0 | 0.4692811071872711 | 0.4692811071872711 | {'label': 0.4692811071872711} | train |
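Because the task is configured as regression, each output row pairs a scalar label with a scalar predict_score, so standard regression metrics such as mean absolute error and mean squared error can be computed directly from those two columns. The stdlib sketch below applies them only to the nine rows visible in the table above; in practice you would run the same computation over every row returned by get_output_data(). The function name is illustrative.

```python
def regression_errors(pairs):
    """Mean absolute error and mean squared error over (label, prediction) pairs."""
    n = len(pairs)
    mae = sum(abs(y - p) for y, p in pairs) / n
    mse = sum((y - p) ** 2 for y, p in pairs) / n
    return mae, mse


# (label, predict_score) for the nine rows shown above (not the full output)
SHOWN = [
    (0.9450980424880981, 0.4692811071872711),
    (0.0235294122248888, 0.13761655986309052),
    (0.019607843831181526, 0.002599237486720085),
    (0.0, 0.026792803779244423),
    (0.9607843160629272, 0.2996892035007477),
    (0.003921568859368563, 0.04695115610957146),
    (0.0, 0.09780637919902802),
    (1.0, 0.0337691493332386),
    (0.0, 0.4692811071872711),
]

mae, mse = regression_errors(SHOWN)
```

Since every label and score here lies in [0, 1], each absolute error is at most 1, so the MSE can never exceed the MAE; a large gap between them mainly reflects a few badly missed rows such as img_32563.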


