# Using a Custom DataBuilder in SecretFlow (Torch)

The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production.

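This tutorial shows how to load data with a custom DataBuilder and train a model in SecretFlow's multi-party secure environment.
It walks through federated learning with a custom DataBuilder using an image classification task on the Flower dataset.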

## Environment Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob', 'charlie'], address="local", log_to_driver=False)
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')

2023-04-17 15:14:51,955	INFO worker.py:1538 -- Started a local Ray instance.


## Interface Overview

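SecretFlow's `FLModel` supports loading data through a custom DataBuilder, so users can handle data input more flexibly to suit their needs.
The example below shows how to use a custom DataBuilder for federated model training.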

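Steps for using a DataBuilder:
1. Develop with the standalone PyTorch engine: write a DataBuilder function that builds a PyTorch DataLoader. *Note: the dataset_builder function must accept a `stage` parameter.*
2. Wrap each party's DataBuilder function in a factory, `create_dataset_builder`.
3. Build `data_builder_dict` as a mapping `{PYU: dataset_builder}`.
4. Pass `data_builder_dict` to the `dataset_builder` argument of `fit`. The `x` argument then carries whatever input the dataset_builder needs (e.g., in this example, the path to the image folder).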
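The four steps can be illustrated without SecretFlow at all. The snippet below is a minimal, framework-free mock under stated assumptions: `make_builder` and `mock_fit` are hypothetical helpers, plain strings stand in for PYU devices, and a list of batches stands in for a DataLoader. It only shows how a `{party: builder}` dict and a `{party: x}` dict are paired up; it is not how `FLModel.fit` is actually implemented.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for a DataLoader: a list of batches.
Batches = List[List[int]]
Builder = Callable[..., Tuple[Batches, int]]


def make_builder(batch_size: int = 32) -> Builder:
    """Step 2: wrap the per-party logic in a factory (hypothetical helper)."""

    def dataset_builder(x: List[int], stage: str = "train") -> Tuple[Batches, int]:
        # Step 1: single-machine logic; here we simply batch the input list.
        # `stage` is accepted because FLModel passes it; this mock ignores it.
        batches = [x[i:i + batch_size] for i in range(0, len(x), batch_size)]
        steps_per_epoch = math.ceil(len(x) / batch_size)
        return batches, steps_per_epoch

    return dataset_builder


def mock_fit(x: Dict[str, List[int]], dataset_builder: Dict[str, Builder]) -> Dict[str, int]:
    """Step 4 (mock): call each party's builder on that party's own input."""
    steps = {}
    for party, builder in dataset_builder.items():
        _, steps[party] = builder(x[party], stage="train")
    return steps


# Step 3: one builder per party (strings stand in for PYU devices).
data_builder_dict = {"alice": make_builder(batch_size=32), "bob": make_builder(batch_size=32)}
x = {"alice": list(range(100)), "bob": list(range(100))}
print(mock_fit(x, data_builder_dict))  # {'alice': 4, 'bob': 4}
```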


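To use a DataBuilder in FLModel, you must define the data builder dict in advance. Each builder must return the dataset object (a `DataLoader` in this Torch example) together with `steps_per_epoch`, and every party must return the same `steps_per_epoch`.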
```python
data_builder_dict = {
    # Each factory returns a builder; calling that builder must return (dataset, steps_per_epoch).
    alice: create_alice_dataset_builder(
        batch_size=32,
    ),
    bob: create_bob_dataset_builder(
        batch_size=32,
    ),
}

```
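Because every party must report the same `steps_per_epoch`, it can help to check this locally before calling `fit`. The helper below, `check_steps_consistent`, is a hypothetical sketch (not part of SecretFlow): it calls each builder with the same `stage` and compares the second element of the returned tuple. `toy_builder` is likewise a made-up stand-in whose "dataset" is just a sample count.

```python
import math
from typing import Any, Callable, Dict, Tuple


def check_steps_consistent(
    builders: Dict[str, Callable[..., Tuple[Any, int]]],
    inputs: Dict[str, Any],
    stage: str = "train",
) -> int:
    """Return the shared steps_per_epoch, or raise if the parties disagree."""
    steps = {party: builder(inputs[party], stage=stage)[1] for party, builder in builders.items()}
    if len(set(steps.values())) != 1:
        raise ValueError(f"steps_per_epoch differs across parties: {steps}")
    return next(iter(steps.values()))


def toy_builder(n_samples: int, stage: str = "train", batch_size: int = 32):
    # Toy builder: no real dataset, just ceil(n_samples / batch_size) steps.
    return None, math.ceil(n_samples / batch_size)


builders = {"alice": toy_builder, "bob": toy_builder}
print(check_steps_consistent(builders, {"alice": 1000, "bob": 1000}))  # 32
```

If one party had, say, 500 samples instead of 1000, its builder would report 16 steps and the helper would raise `ValueError` instead of letting training start with mismatched epochs.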

## Download the Data

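About the Flower dataset: it contains 3,670 color photos of 5 flower classes (daisy, dandelion, rose, sunflower, tulip). Each class includes photos taken from several angles and under different lighting, and image sizes vary (this tutorial resizes every image to 180x180). The dataset is commonly used for training and testing image classification and other machine learning algorithms. The per-class counts are: daisy (633), dandelion (898), rose (641), sunflower (699), tulip (799).  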
  
Download link: [http://download.tensorflow.org/example_images/flower_photos.tgz](http://download.tensorflow.org/example_images/flower_photos.tgz)
<img alt="flower_dataset_demo.png" src="resources/flower_dataset_demo.png" width="600">  



### Download and Extract the Data

In [3]:
# Reuse TensorFlow's download utility to fetch the images; the result is a folder, as shown in the figure above
import tempfile
import tensorflow as tf

_temp_dir = tempfile.mkdtemp()
path_to_flower_dataset = tf.keras.utils.get_file(
    "flower_photos",
    "https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz",
    untar=True,
    cache_dir=_temp_dir,
)

Downloading data from https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz


# Building a Custom DataBuilder

## 1. Develop the DataBuilder with the Standalone Engine

When developing a `DataBuilder`, you can simply follow the usual single-machine development workflow.  
The goal is just to build a `DataLoader` object in `Torch`.

In [4]:
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms

# Parameters
batch_size = 32
shuffle = True
random_seed = 1234
train_split = 0.8

# Define dataset
flower_transform = transforms.Compose(
    [
        transforms.Resize((180, 180)),
        transforms.ToTensor(),
    ]
)
flower_dataset = datasets.ImageFolder(path_to_flower_dataset, transform=flower_transform)
dataset_size = len(flower_dataset)
# Define sampler

indices = list(range(dataset_size))
if shuffle:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
split = int(np.floor(train_split * dataset_size))
train_indices, val_indices = indices[:split], indices[split:]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# Define databuilder
train_loader = DataLoader(
    flower_dataset, batch_size=batch_size, sampler=train_sampler
)
valid_loader = DataLoader(
    flower_dataset, batch_size=batch_size, sampler=valid_sampler
)
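The train/validation split in the cell above depends only on the seed, so every party that runs it with the same `random_seed` gets the same indices. The stdlib sketch below mirrors that shuffle-then-cut logic, with `random.Random` swapped in for `np.random` (so the concrete shuffled order differs from NumPy's); `split_indices` is a hypothetical helper name.

```python
import math
import random
from typing import List, Tuple


def split_indices(
    dataset_size: int, train_split: float = 0.8, shuffle: bool = True, seed: int = 1234
) -> Tuple[List[int], List[int]]:
    """Shuffle indices with a fixed seed, then cut at floor(train_split * size)."""
    indices = list(range(dataset_size))
    if shuffle:
        # A private generator avoids touching the global random state.
        random.Random(seed).shuffle(indices)
    split = math.floor(train_split * dataset_size)
    return indices[:split], indices[split:]


train_a, val_a = split_indices(100)
train_b, val_b = split_indices(100)
print(train_a == train_b, len(train_a), len(val_a))  # True 80 20
```

With `shuffle=False`, as used in the `data_builder_dict` later in this notebook, the split is simply the first 80% of indices for training and the rest for validation.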

In [5]:
x, y = next(iter(train_loader))
print(f"x.shape = {x.shape}")
print(f"y.shape = {y.shape}")

x.shape = torch.Size([32, 3, 180, 180])
y.shape = torch.Size([32])


## 2. Wrap the DataBuilder

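At run time the DataBuilder has to be shipped to, and executed on, each party's machine, so we wrap it in a factory function to make it serializable.  
Note that:
- FLModel requires the DataBuilder to accept a `stage` parameter (e.g. `stage="train"`).
- FLModel requires the DataBuilder to return two values: `(data_set, steps_per_epoch)`.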
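The two requirements above can be sketched with plain lists standing in for DataLoaders: the outer factory captures the configuration, and the inner function dispatches on `stage` and always returns a `(data, steps_per_epoch)` pair. `create_toy_builder` is a hypothetical name, and unlike the notebook's builder, this sketch raises on an unknown stage instead of implicitly returning `None`.

```python
import math
from typing import Callable, List, Sequence, Tuple


def create_toy_builder(
    batch_size: int = 32, train_split: float = 0.8
) -> Callable[..., Tuple[List[Sequence[int]], int]]:
    """Outer factory: captures configuration so the inner builder is self-contained."""

    def dataset_builder(x: Sequence[int], stage: str = "train"):
        split = math.floor(train_split * len(x))
        if stage == "train":
            data = x[:split]
        elif stage == "eval":
            data = x[split:]
        else:
            raise ValueError(f"unsupported stage: {stage!r}")
        batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
        # Always return exactly two values: (data, steps_per_epoch).
        return batches, math.ceil(len(data) / batch_size)

    return dataset_builder


builder = create_toy_builder(batch_size=32)
_, train_steps = builder(list(range(1000)), stage="train")
_, eval_steps = builder(list(range(1000)), stage="eval")
print(train_steps, eval_steps)  # 25 7
```

Here 800 training samples give `ceil(800 / 32) = 25` steps and 200 validation samples give `ceil(200 / 32) = 7` steps, the same `math.ceil` arithmetic the real builder uses.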

In [6]:
def create_dataset_builder(
        batch_size=32,
        train_split=0.8,
        shuffle=True,
        random_seed=1234,
    ):
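        def dataset_builder(x, stage="train"):
            """Build a DataLoader from the ImageFolder at path `x`.

            Returns (train_loader, train_steps_per_epoch) for stage="train"
            and (valid_loader, eval_steps_per_epoch) for stage="eval".
            """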
            import math

            import numpy as np
            from torch.utils.data import DataLoader
            from torch.utils.data.sampler import SubsetRandomSampler
            from torchvision import datasets, transforms

            # Define dataset
            flower_transform = transforms.Compose(
                [
                    transforms.Resize((180, 180)),
                    transforms.ToTensor(),
                ]
            )
            flower_dataset = datasets.ImageFolder(x, transform=flower_transform)
            dataset_size = len(flower_dataset)
            # Define sampler

            indices = list(range(dataset_size))
            if shuffle:
                np.random.seed(random_seed)
                np.random.shuffle(indices)
            split = int(np.floor(train_split * dataset_size))
            train_indices, val_indices = indices[:split], indices[split:]
            train_sampler = SubsetRandomSampler(train_indices)
            valid_sampler = SubsetRandomSampler(val_indices)

            # Define databuilder
            train_loader = DataLoader(
                flower_dataset, batch_size=batch_size, sampler=train_sampler
            )
            valid_loader = DataLoader(
                flower_dataset, batch_size=batch_size, sampler=valid_sampler
            )

            # Return
            if stage == "train":
                train_step_per_epoch = math.ceil(split / batch_size)
                
                return train_loader, train_step_per_epoch
            elif stage == "eval":
                eval_step_per_epoch = math.ceil((dataset_size - split) / batch_size)
                return valid_loader, eval_step_per_epoch

        return dataset_builder

## 3. Build the dataset_builder_dict

In [7]:
# Prepare the dataset builder dict (one builder per party)
data_builder_dict = {
    alice: create_dataset_builder(
        batch_size=32,
        train_split=0.8,
        shuffle=False,
        random_seed=1234,
    ),
    bob: create_dataset_builder(
        batch_size=32,
        train_split=0.8,
        shuffle=False,
        random_seed=1234,
    ),
}

## 4. With the dataset_builder_dict ready, we can use it for federated training

# Defining a Torch-Backend FLModel for Training

### Define the Model Architecture

In [8]:
from torch import nn

from secretflow.ml.nn.utils import BaseModule

class ConvRGBNet(BaseModule):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.network = nn.Sequential(
            nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(16 * 45 * 45, 128),
            nn.ReLU(),
            nn.Linear(128, 5),
        )

    def forward(self, xb):
        return self.network(xb)
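The `16 * 45 * 45` input size of the first `nn.Linear` follows from the layer arithmetic: each 3x3 convolution with stride 1 and padding 1 keeps the spatial size, and each 2x2 max-pool with stride 2 halves it, so 180 becomes 90 and then 45. The sketch below reproduces that number with the standard output-size formula `floor((n + 2p - k) / s) + 1` (dilation 1, floor mode), applied to both conv and pool layers; `out_size` is a hypothetical helper name.

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d/MaxPool2d layer (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


size = 180
# (kernel, stride, padding) for conv, pool, conv, pool in ConvRGBNet
for kernel, stride, padding in [(3, 1, 1), (2, 2, 0), (3, 1, 1), (2, 2, 0)]:
    size = out_size(size, kernel, stride, padding)

flat_features = 16 * size * size  # 16 output channels from the last conv
print(size, flat_features)  # 45 32400
```

If you change the `Resize` target in the DataBuilder, rerunning this arithmetic tells you the new `in_features` for the first linear layer.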

In [9]:
from secretflow.ml.nn import FLModel
from secretflow.security.aggregation import SecureAggregator
from torch import nn, optim
from torchmetrics import Accuracy, Precision
from secretflow.ml.nn.fl.utils import metric_wrapper, optim_wrapper
from secretflow.ml.nn.utils import TorchModel


In [10]:
device_list = [alice, bob]
aggregator = SecureAggregator(charlie, [alice, bob])
# prepare model
num_classes = 5

input_shape = (180, 180, 3)
# torch model
loss_fn = nn.CrossEntropyLoss
optim_fn = optim_wrapper(optim.Adam, lr=1e-3)
model_def = TorchModel(
    model_fn=ConvRGBNet,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    metrics=[
        metric_wrapper(
            Accuracy, task="multiclass", num_classes=num_classes, average='micro'
        ),
        metric_wrapper(
            Precision, task="multiclass", num_classes=num_classes, average='micro'
        ),
    ],
)

fed_model = FLModel(
    device_list=device_list,
    model=model_def,
    aggregator=aggregator,
    backend="torch",
    strategy="fed_avg_w",
    random_seed=1234,
)

INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party bob.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party bob.


The input to our dataset builder is the path to the image dataset, so the input data here must be a `Dict` mapping each party to its own folder path:
```python
data = {
    alice: folder_path_of_alice,
    bob: folder_path_of_bob
}
```

In [11]:
data = {
    alice: path_to_flower_dataset,
    bob: path_to_flower_dataset,
}
history = fed_model.fit(
    data,
    None,
    validation_data=data,
    epochs=5,
    batch_size=32,
    aggregate_freq=2,
    sampler_method="batch",
    random_seed=1234,
    dp_spent_step_freq=1,
    dataset_builder=data_builder_dict
)

INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7ff4efc87af0>, 'x': {alice: '/tmp/tmp59nrtvl5/datasets/flower_photos', bob: '/tmp/tmp59nrtvl5/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 5, 'verbose': 1, 'callbacks': None, 'validation_data': {alice: '/tmp/tmp59nrtvl5/datasets/flower_photos', bob: '/tmp/tmp59nrtvl5/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {alice: <function create_dataset_builder.<locals>.dataset_builder at 0x7ff600fb7ee0>, bob: <function create_dataset_builder.<locals>.dataset_builder at 0x7ff6007148b0>}}
100%|██████████| 30/30 [00:32<00:00,  1.08s/it, epoch: 1/5 -  multiclassaccuracy:0.3760416805744171  m