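# Homo NN Custom Dataset: Image + Text Features

In this version the NN architecture has been substantially reworked. The nn module now includes a dataset module, which is intended to support custom datasets alongside custom models. The previous tutorial covered model customization; this tutorial shows how to customize a Dataset, using a dataset with mixed feature types as the example.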

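## Developing a Dataset That Mixes Images and Text

This task uses the Flickr image captions dataset. Because the original dataset is large, we use a reduced copy located at examples/data/flicker_toy_data.

We roughly split it into two classes: photos taken in the wild, each with a text description, and photos taken in the city, each with a text description.
Using the images and text as features, we develop a classification model. The first step is to develop a dataset.

First, let's look at the raw data.

### Raw Data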

In [1]:
import pandas as pd
df_text = pd.read_csv('../examples/data/flicker_toy_data/text.csv')

In [2]:
df_text

Unnamed: 0,id,text
0,1022454428_b6b660a67b,A man and woman care for an infant along the s...
1,103195344_5d2dc613a3,A man sitting in front of a metal sculpture in...
2,1055753357_4fa3d8d693,Two construction workers take a seat on a stee...
3,1124448967_2221af8dc5,Man relaxing in a folding chair on the street .
4,1131804997_177c3c0640,Two men with backpacks are sitting on cardboar...
...,...,...
210,78984436_ad96eaa802,Two German shepherd dogs are running with an o...
211,84713990_d3f3cef78b,Several people are rafting down a choppy river...
212,90011335_cfdf9674c2,A white boat on glassy water with mountains in...
213,96973080_783e375945,A dog runs through the snow .


In [3]:
# The folder contains two classes of images
import os
os.listdir('../examples/data/flicker_toy_data/flicker/images/')

['wild', 'city']
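
The two folder names become the class labels once the images are loaded with torchvision's `ImageFolder`, which sorts class folders by name and numbers them from 0. The snippet below is a minimal sketch of that convention, assuming these two folders are the only entries in the directory; it does not call torchvision itself.

```python
# Sketch of the ImageFolder label convention: class folders are sorted
# by name and then numbered starting from 0.
folders = ['wild', 'city']  # order as returned by os.listdir above

classes = sorted(folders)
class_to_idx = {name: idx for idx, name in enumerate(classes)}

print(class_to_idx)  # {'city': 0, 'wild': 1}
```

Under this convention, city photos receive label 0 and wild photos receive label 1.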

In [4]:
# Each image has an id that matches an id in df_text
os.listdir('../examples/data/flicker_toy_data/flicker/images/city')

['1332492622_8c66992b62.jpg',
 '241345639_1556a883b1.jpg',
 '47871819_db55ac4699.jpg',
 '617038406_4092ee91dd.jpg',
 '616045808_0286d0574b.jpg',
 '1055753357_4fa3d8d693.jpg',
 '211295363_49010ca38d.jpg',
 '635444010_bd81c89ab7.jpg',
 '309771854_952aabe3cc.jpg',
 '1167908324_8caab45e15.jpg',
 '247619370_a01fb21dd3.jpg',
 '241345323_f53eb5eec4.jpg',
 '191003285_edd8d0cf58.jpg',
 '1143373711_2e90b7b799.jpg',
 '615916000_5044047d71.jpg',
 '241345656_861aacefde.jpg',
 '241346105_c1c860db0d.jpg',
 '489134459_1b3f46fc03.jpg',
 '1355945307_f9e01a9a05.jpg',
 '1984936420_3f3102132b.jpg',
 '241345811_46b5f157d4.jpg',
 '1342766791_1e72f92455.jpg',
 '297285273_688e44c014.jpg',
 '1346051107_9cdc14e070.jpg',
 '191003284_1025b0fb7d.jpg',
 '1467533293_a2656cc000.jpg',
 '1131804997_177c3c0640.jpg',
 '241345905_5826a72da1.jpg',
 '1022454428_b6b660a67b.jpg',
 '1426014905_da60d72957.jpg',
 '317641829_ab2607a6c0.jpg',
 '2073174497_18b779999c.jpg',
 '535249787_0fcaa613a0.jpg',
 '2073964624_52da3a0fc4.jpg',
 

### Dataset

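FATE provides a Dataset base class under nn.dataset. Apart from the additional load interface, the interfaces it requires are identical to those of the PyTorch Dataset. Once a dataset class built on Dataset has been added to the nn.dataset module, FATE can import your custom dataset at runtime according to the job parameters and use it for training. Here we implement a MixFeatureDataset.

We use the convenient Jupyter notebook interface to save the finished dataset code into the nn.dataset module.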
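
As a rough sketch of the interface described above, the toy class below provides `load`, `__len__`, `__getitem__` and the optional `get_sample_ids`. It is a stand-in only: `ToyDataset` is a hypothetical name, it deliberately does not inherit from the FATE base class so that it runs without FATE installed, and its `load` fabricates three samples instead of reading files.

```python
class ToyDataset:
    """Stand-in showing the methods a custom dataset is expected to provide."""

    def __init__(self):
        self.samples = []
        self.labels = []
        self.sample_ids = []

    def load(self, path):
        # A real dataset would read files under `path`; this toy version
        # ignores it and builds three in-memory samples.
        self.sample_ids = ['a', 'b', 'c']
        self.samples = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
        self.labels = [0, 1, 1]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Return (data, label), the same shape of result a PyTorch Dataset returns.
        return self.samples[idx], self.labels[idx]

    def get_sample_ids(self):
        return self.sample_ids


toy = ToyDataset()
toy.load('unused/path')  # hypothetical path, ignored by the toy loader
print(len(toy), toy[1], toy.get_sample_ids())
```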

In [5]:
from pipeline.component.homo_nn import save_to_fate # interface for saving code into FATE

In [6]:
%%save_to_fate dataset mix_feature_ds.py
import torch as t
import pandas as pd
from federatedml.nn.dataset.base import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
from transformers import BertTokenizerFast

class MixFeatureDataset(Dataset):
    
    def __init__(self, output_image_size):
        super(MixFeatureDataset, self).__init__() # do not forget this call
        self.output_image_size = output_image_size
        self.image_folder = None
        self.text = None
        self.word_idx = None
        self.vocab_size = 0
        self.sample_ids = None
        
    # Required interface: load. It takes one argument, path; the algorithm passes path to it at runtime
    def load(self, path):
        # Build the image dataset
        transformer = transforms.Compose([transforms.CenterCrop(size=self.output_image_size), transforms.ToTensor()])
        self.image_folder = ImageFolder(root=path+'/flicker/images', transform=transformer)
        
        # Build the text dataset and tokenize it
        import os
        os.environ["TOKENIZERS_PARALLELISM"] = "false" # avoid tokenizer fork warnings with DataLoader workers
        
        self.text = pd.read_csv(path+'/text.csv')
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') # use the BERT tokenizer
        text_list = list(self.text.text)
        self.word_idx = tokenizer(text_list, padding=True, return_tensors='pt',
                                  truncation=True, max_length=20)['input_ids']
        self.vocab_size = tokenizer.vocab_size
        
        # Make sure the image ids line up, in order, with the text ids
        img_ids = [os.path.splitext(os.path.basename(i[0]))[0] for i in self.image_folder.imgs]
        text_ids = list(self.text.id)
        assert img_ids == text_ids
        print('id match!')
        self.sample_ids = text_ids
    
    
    # Required interface 1: __len__
    def __len__(self):
        return len(self.image_folder)
    
    # Required interface 2: __getitem__, which returns (data, label)
    def __getitem__(self, idx):
        img, label = self.image_folder[idx]
        text = self.word_idx[idx]
        return (img, text), t.tensor(label).type(t.float32)
    
    # Optional interface: if it is not implemented, FATE cannot obtain the sample ids and generates them automatically
    def get_sample_ids(self):
        return self.sample_ids

In [7]:
ds = MixFeatureDataset((224, 224)) # target image size
ds.load('../examples/data/flicker_toy_data/')

id match!
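
The `id match!` message is printed only if the assertion in `load` passes, which requires the rows of text.csv to be in the same order in which the images are enumerated: class folders in sorted order, then file names in sorted order within each folder. The sketch below reproduces that ordering with the standard library on a throwaway directory. The helper `image_ids_in_folder_order` and the file names are hypothetical, and the real `ImageFolder` additionally filters by file extension.

```python
import os
import tempfile


def image_ids_in_folder_order(root):
    """List image ids: sorted class folders first, then sorted file names."""
    ids = []
    for class_name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, class_name)
        for file_name in sorted(os.listdir(class_dir)):
            ids.append(os.path.splitext(file_name)[0])
    return ids


with tempfile.TemporaryDirectory() as root:
    # Build a tiny two-class layout with empty placeholder files.
    layout = {'wild': ['30_c.jpg', '10_a.jpg'], 'city': ['20_b.jpg']}
    for class_name, files in layout.items():
        os.makedirs(os.path.join(root, class_name))
        for file_name in files:
            open(os.path.join(root, class_name, file_name), 'w').close()
    img_ids = image_ids_in_folder_order(root)

# The text table must list its ids in exactly this order for the check to pass.
text_ids = ['20_b', '10_a', '30_c']
ids_match = img_ids == text_ids
print(img_ids, ids_match)
```

If the text file were sorted differently, the comparison would fail and the two feature sources would be paired incorrectly, which is why `load` asserts on it.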


In [8]:
ds[0]

((tensor([[[0.5059, 0.5176, 0.5137,  ..., 0.4941, 0.5020, 0.5059],
           [0.4980, 0.5020, 0.4980,  ..., 0.4824, 0.5020, 0.5059],
           [0.5059, 0.4863, 0.4902,  ..., 0.4980, 0.4980, 0.5137],
           ...,
           [0.7843, 0.7922, 0.7529,  ..., 0.1412, 0.2078, 0.2196],
           [0.9922, 0.9922, 0.9647,  ..., 0.1176, 0.0941, 0.1333],
           [0.9961, 0.9922, 1.0000,  ..., 0.1647, 0.1294, 0.1373]],
  
          [[0.5765, 0.5882, 0.5843,  ..., 0.5490, 0.5569, 0.5608],
           [0.5686, 0.5804, 0.5765,  ..., 0.5490, 0.5529, 0.5529],
           [0.5608, 0.5569, 0.5647,  ..., 0.5569, 0.5490, 0.5529],
           ...,
           [0.7961, 0.8039, 0.7490,  ..., 0.1373, 0.1882, 0.2000],
           [0.9961, 0.9961, 0.9608,  ..., 0.1137, 0.1137, 0.1529],
           [0.9922, 0.9922, 1.0000,  ..., 0.1608, 0.1059, 0.1216]],
  
          [[0.6235, 0.6353, 0.6314,  ..., 0.5922, 0.6000, 0.6118],
           [0.6078, 0.6235, 0.6196,  ..., 0.5804, 0.5882, 0.6000],
           [0.6039, 0.

### Testing the Dataset

In [9]:
from torch.utils.data import DataLoader

for i in DataLoader(ds, batch_size=2):
    break

In [10]:
i

[[tensor([[[[0.5059, 0.5176, 0.5137,  ..., 0.4941, 0.5020, 0.5059],
            [0.4980, 0.5020, 0.4980,  ..., 0.4824, 0.5020, 0.5059],
            [0.5059, 0.4863, 0.4902,  ..., 0.4980, 0.4980, 0.5137],
            ...,
            [0.7843, 0.7922, 0.7529,  ..., 0.1412, 0.2078, 0.2196],
            [0.9922, 0.9922, 0.9647,  ..., 0.1176, 0.0941, 0.1333],
            [0.9961, 0.9922, 1.0000,  ..., 0.1647, 0.1294, 0.1373]],
  
           [[0.5765, 0.5882, 0.5843,  ..., 0.5490, 0.5569, 0.5608],
            [0.5686, 0.5804, 0.5765,  ..., 0.5490, 0.5529, 0.5529],
            [0.5608, 0.5569, 0.5647,  ..., 0.5569, 0.5490, 0.5529],
            ...,
            [0.7961, 0.8039, 0.7490,  ..., 0.1373, 0.1882, 0.2000],
            [0.9961, 0.9961, 0.9608,  ..., 0.1137, 0.1137, 0.1529],
            [0.9922, 0.9922, 1.0000,  ..., 0.1608, 0.1059, 0.1216]],
  
           [[0.6235, 0.6353, 0.6314,  ..., 0.5922, 0.6000, 0.6118],
            [0.6078, 0.6235, 0.6196,  ..., 0.5804, 0.5882, 0.6000],
      

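Nice! We now have a dataset. Next we develop a matching model and test the pair locally; if it runs, it can be submitted as a federated task. Note that only a single dataset is used here for ease of demonstration. In practice, you can develop and validate locally on one party's dataset, then use all parties' data in the federated task.

## Customizing a Model

Here we define a model that processes both images and text and save it under nn.model_zoo, with the module name flicker_classifier.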

In [11]:
%%save_to_fate model flicker_classifier.py
import torch as t
from torch import nn

class FlickerClassifier(nn.Module):
    
    def __init__(self, vocab_size, word_embed_size=8):
        super(FlickerClassifier, self).__init__()
        
        # Image branch
        self.cv_seq = t.nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
            nn.MaxPool2d(kernel_size=3),
            nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3),
            nn.AvgPool2d(kernel_size=5)
        )
        self.fc = t.nn.Sequential(
            nn.Linear(1176, 32),
            nn.ReLU(),
            nn.Linear(32, 8)
        )
        # Text branch
        self.word_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=word_embed_size)
        self.lstm_seq = nn.LSTM(input_size=word_embed_size, hidden_size=word_embed_size, batch_first=True)
        # Classifier
        self.classifier_seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(word_embed_size + 8, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        image_feat = x[0]  # batch of images
        word_feat = x[1]   # batch of token ids
        
        image_feat = self.fc(self.cv_seq(image_feat).flatten(start_dim=1))
        word_feat, _ = self.lstm_seq(self.word_embedding(word_feat))
        word_feat = word_feat.sum(dim=1)  # sum LSTM outputs over the sequence
        return self.classifier_seq(t.cat([image_feat, word_feat], dim=1)).flatten()
        

In [12]:
model = FlickerClassifier(ds.vocab_size, 8)

In [13]:
model(i[0])

tensor([0.6456, 0.7330], grad_fn=<ReshapeAliasBackward0>)
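
The `nn.Linear(1176, 32)` layer in the image branch depends on the 224x224 input size. The arithmetic below is a small sketch of where 1176 comes from, assuming the PyTorch defaults used in the model: no padding, stride 1 for the convolutions, and a pooling stride equal to the pooling kernel size. The helper `out_size` is a hypothetical name.

```python
def out_size(size, kernel, stride):
    """Spatial output size of a conv/pool layer without padding."""
    return (size - kernel) // stride + 1


size = 224
size = out_size(size, kernel=5, stride=1)  # Conv2d(kernel_size=5)    -> 220
size = out_size(size, kernel=3, stride=3)  # MaxPool2d(kernel_size=3) -> 73
size = out_size(size, kernel=3, stride=1)  # Conv2d(kernel_size=3)    -> 71
size = out_size(size, kernel=5, stride=5)  # AvgPool2d(kernel_size=5) -> 14

channels = 6
flat = channels * size * size
print(size, flat)  # 14 1176
```

Passing a different `output_image_size` to the dataset would change this number, so the first linear layer would have to be resized accordingly.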

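## Local Validation

We test locally whether this combination (dataset + model) runs. In local validation, the federated aggregation step is skipped.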
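
For context on what local mode leaves out: in the textbook FedAvg formulation, each client trains locally and the server then averages the clients' parameters, weighted by their sample counts. The sketch below illustrates that averaging on plain Python lists. It is not FATE's implementation; `fed_avg` and the example numbers are hypothetical.

```python
def fed_avg(client_weights, client_sizes):
    """Sample-weighted average of per-client parameter vectors."""
    total = sum(client_sizes)
    n_params = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
        for i in range(n_params)
    ]


weights_a = [1.0, 2.0]  # parameters from a client with 100 samples
weights_b = [3.0, 6.0]  # parameters from a client with 300 samples
averaged = fed_avg([weights_a, weights_b], client_sizes=[100, 300])
print(averaged)  # [2.5, 5.0]
```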

In [14]:
import torch as t
from federatedml.nn.backend.utils.common import global_seed
from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer

global_seed(100)
ds = MixFeatureDataset((224, 224)) # target image size
ds.load('../examples/data/flicker_toy_data/')
model = FlickerClassifier(ds.vocab_size, 8)
optimizer = t.optim.Adam(model.parameters(), lr=0.01)
loss = t.nn.BCELoss()
trainer = FedAVGTrainer(epochs=5, batch_size=64, data_loader_worker=4)
trainer.local_mode()  # important: run locally, skipping federated aggregation
trainer.set_model(model)

id match!


In [15]:
trainer.train(ds, optimizer=optimizer, loss=loss)

epoch is 0
epoch loss is 0.846086848059366
epoch is 1
epoch loss is 0.7933921104253725
epoch is 2
epoch loss is 0.7724235889523529
epoch is 3
epoch loss is 0.7441925581111465
epoch is 4
epoch loss is 0.7270668562068495
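
The epoch losses above fall from about 0.85 to about 0.73. As a point of reference, a predictor that always outputs 0.5 has a binary cross-entropy of ln 2 (about 0.693) whatever the labels are, so five epochs on this toy data still leave the model near chance level. The sketch below computes that baseline by hand. It assumes the reported epoch loss is a mean BCE, which the output does not state explicitly; `bce` is a hypothetical helper.

```python
import math


def bce(scores, labels):
    """Mean binary cross-entropy for probabilities in (0, 1)."""
    losses = [
        -(y * math.log(p) + (1 - y) * math.log(1 - p))
        for p, y in zip(scores, labels)
    ]
    return sum(losses) / len(losses)


chance = bce([0.5, 0.5], [1.0, 0.0])
print(round(chance, 4))  # 0.6931
```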


## Submitting Homo-NN

With local validation complete and the code confirmed to run, we can submit the model through pipeline.

In [16]:
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model

fate_torch_hook(t)


import os
# Bind the data path to a FATE name & namespace
fate_project_path = os.path.abspath('../')
host_0 = 10000
host_1 = 9999
pipeline = PipeLine().set_initiator(role='host', party_id=host_0).set_roles(host=[host_0, host_1],
                                                                            arbiter=[host_0])
data_0 = {"name": "flicker", "namespace": "experiment"}
# For convenience, both clients use the same dataset in this example
data_path = fate_project_path + '/examples/data/flicker_toy_data'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)

{'namespace': 'experiment', 'table_name': 'flicker'}

In [17]:
# Define the reader
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='host', party_id=host_0).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host_1).component_param(table=data_0)

In [18]:
from pipeline.component.homo_nn import DatasetParam, TrainerParam  # dataset and trainer parameter interfaces

model = t.nn.Sequential(
    t.nn.CustModel(module_name='flicker_classifier', class_name='FlickerClassifier', vocab_size=ds.vocab_size, word_embed_size=8)
)

nn_component = HomoNN(name='nn_0',
                      model=model, # model
                      loss=t.nn.BCELoss(),
                      optimizer=t.optim.Adam(model.parameters(), lr=0.01),
                      dataset=DatasetParam(dataset_name='mix_feature_ds', output_image_size=(224, 224)),  # use the custom dataset
                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=5, batch_size=64, data_loader_worker=4, validation_freqs=1,
                                           secure_aggregate=False),
                      torch_seed=100 # global random seed
                      )

In [19]:
# Add the components to the pipeline, define the data I/O relations, and submit
pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='binary'), data=Data(data=nn_component.output.data))

<pipeline.backend.pipeline.PipeLine at 0x7fa42bf1f580>

In [None]:
pipeline.compile()
pipeline.fit()

2022-11-11 12:32:24.100 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202211111232233316510
2022-11-11 12:32:24.112 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:00
2022-11-11 12:32:25.171 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:125 -
2022-11-11 12:32:25.172 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:01
2022-11-11 12:32:26.206 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:02
2022-11-11 12:32:27.243 | I

2022-11-11 12:33:05.059 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:40
2022-11-11 12:33:06.165 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:42
2022-11-11 12:33:07.310 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:43
2022-11-11 12:33:08.498 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:00:44
2022-11-11 12:33:09.619 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running c

2022-11-11 12:33:50.540 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:01:26
2022-11-11 12:33:51.727 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:01:27
2022-11-11 12:33:52.834 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:01:28
2022-11-11 12:33:53.943 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:01:29
2022-11-11 12:33:55.146 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running c

2022-11-11 12:34:36.246 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:02:12
2022-11-11 12:34:37.362 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:02:13
2022-11-11 12:34:38.480 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:02:14
2022-11-11 12:34:39.617 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:02:15
2022-11-11 12:34:40.716 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running c

2022-11-11 12:35:22.097 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:02:57
2022-11-11 12:35:23.199 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:02:59
2022-11-11 12:35:24.306 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:03:00
2022-11-11 12:35:25.451 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component nn_0, time elapse: 0:03:01
2022-11-11 12:35:26.574 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running c

Training complete! You can view the results in FATEBoard, or:

In [21]:
pipeline.get_component('nn_0').get_output_data()

Unnamed: 0,id,label,predict_result,predict_score,predict_detail,type
0,667626_18933d713e,1.0,0,0.3941565454006195,"{'0': 0.6058434545993805, '1': 0.3941565454006...",train
1,3637013_c675de7705,1.0,1,0.5649510622024536,"{'0': 0.4350489377975464, '1': 0.5649510622024...",train
2,10815824_2997e03d76,1.0,1,0.575126051902771,"{'0': 0.424873948097229, '1': 0.575126051902771}",train
3,17273391_55cfc7d3d4,1.0,1,0.6380451321601868,"{'0': 0.36195486783981323, '1': 0.638045132160...",train
4,19212715_20476497a3,1.0,0,0.42395642399787903,"{'0': 0.576043576002121, '1': 0.42395642399787...",train
...,...,...,...,...,...,...
210,3693961165_9d6c333d5b,1.0,1,0.5886315107345581,"{'0': 0.4113684892654419, '1': 0.5886315107345...",train
211,3715559023_70c41b31c7,1.0,1,0.5360317230224609,"{'0': 0.46396827697753906, '1': 0.536031723022...",train
212,3719461451_07de35af3a,1.0,1,0.5874216556549072,"{'0': 0.4125783443450928, '1': 0.5874216556549...",train
213,3745451546_fc8ec70cbd,1.0,1,0.5361397862434387,"{'0': 0.4638602137565613, '1': 0.5361397862434...",train
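
In the rows shown above, `predict_detail` holds the score under key `'1'` and its complement under key `'0'`, and `predict_result` is 1 where the score exceeds one half. The sketch below reproduces that relationship for two of the displayed scores. It is inferred from the displayed output rather than taken from FATE's code; the function name `to_prediction` and the exact 0.5 threshold are assumptions.

```python
def to_prediction(score, threshold=0.5):
    """Turn a sigmoid score into a hard label plus per-class probabilities."""
    label = 1 if score > threshold else 0  # assumed threshold
    detail = {'0': 1.0 - score, '1': score}
    return label, detail


label_0, detail_0 = to_prediction(0.3941565454006195)  # score from row 0
label_1, detail_1 = to_prediction(0.5649510622024536)  # score from row 1
print(label_0, detail_0)
print(label_1, detail_1)
```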
