# Vision Transformer Training with Parameter-Efficient Methods in FATE-ViT

In this tutorial, we train a federated ViT on the CIFAR-10 dataset.

For more details on dataset settings in FATE-LLM, we recommend reading these tutorials first: [NN Dataset Customization](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-your-Dataset.ipynb) and [Some Built-In Dataset](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Introduce-Built-In-Dataset.ipynb).

## Use the PELLM Model in FATE with CustModel

In the [Model Customization](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-Model.ipynb) tutorial, we demonstrate how to use the t.nn.CustModel class in fate_torch to parse a model's structure and submit it to a federated learning task. CustModel automatically imports the model class from the model_zoo and initializes the model with the parameters provided. Since the PELLM models, including the ViT used here, are built in, we can use them directly in CustModel and easily add a classifier head to address the classification task at hand:
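
To make the import-and-initialize step concrete, here is a minimal sketch of how a module name, a class name, and keyword arguments can be resolved into an object using only the standard library. The helper `build_from_spec` is hypothetical and is not FATE's implementation; it resolves against a standard-library module purely so that the example runs anywhere.

```python
import importlib


def build_from_spec(module_name, class_name, **kwargs):
    """Import `module_name`, look up `class_name`, and instantiate it with kwargs.

    Hypothetical helper for illustration only; CustModel's real lookup
    happens inside FATE and may differ.
    """
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    return cls(**kwargs)


# Stand-in for a model class: fractions.Fraction from the standard library.
half = build_from_spec("fractions", "Fraction", numerator=1, denominator=2)
print(half)
```

The same pattern explains why the module_name and class_name passed to CustModel must match a file and a class that exist in the model_zoo.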

In [None]:
import torch as t
from pipeline import fate_torch_hook
from pipeline.component.nn import save_to_fate_llm
fate_torch_hook(t)

In [None]:
%%save_to_fate_llm model sigmoid.py

import torch as t

class Sigmoid(t.nn.Module):
    
    def __init__(self):
        super().__init__()
        self.sigmoid = t.nn.Sigmoid()
        
    def forward(self, x):
        return self.sigmoid(x.logits)
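
The wrapper reads `x.logits` because Hugging Face classification models return an output object rather than a bare tensor, so a plain activation layer placed directly after them in a sequential container would receive the wrong type. The torch-free sketch below mimics that behaviour; `FakeOutput` and `sigmoid_over_logits` are illustrative stand-ins and are not part of FATE or transformers.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class FakeOutput:
    """Stand-in for a model output object that exposes a `logits` field."""
    logits: List[float]


def sigmoid_over_logits(output: FakeOutput) -> List[float]:
    """Unwrap `.logits` and squash each value into the open interval (0, 1)."""
    return [1.0 / (1.0 + math.exp(-z)) for z in output.logits]


scores = sigmoid_over_logits(FakeOutput(logits=[-2.0, 0.0, 2.0]))
print(scores)
```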

## Submit Federated Task
Once you have successfully completed local testing, you can submit a task to FATE. **Please note that this tutorial is run on a standalone version. If you are using a cluster version, you need to bind the data with the corresponding name and namespace on each machine.**

In this example, we load pretrained weights for the ViT model (google/vit-base-patch16-224).

In [None]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=1000)

In [None]:
print(model)
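
The printed structure contains `query` and `value` linear projections in every attention block, which are the modules targeted by the LoRA configuration used later in this tutorial. As a rough sketch of why this is parameter-efficient, the arithmetic below counts the adapter weights for rank 16, assuming the ViT-Base dimensions (hidden size 768, 12 encoder layers). The helper name is illustrative, and the count ignores the classifier head that modules_to_save keeps trainable.

```python
def lora_param_count(in_features, out_features, rank):
    """LoRA adds A (rank x in_features) and B (out_features x rank) per adapted layer."""
    return rank * in_features + out_features * rank


# Assumed ViT-Base dimensions; rank and lora_alpha mirror the tutorial's LoraConfig.
hidden_size = 768
num_layers = 12
targets_per_layer = 2          # "query" and "value"
rank, lora_alpha = 16, 16

adapted_layers = num_layers * targets_per_layer
lora_params = adapted_layers * lora_param_count(hidden_size, hidden_size, rank)
frozen_params = adapted_layers * hidden_size * hidden_size  # weights only, biases excluded
scaling = lora_alpha / rank    # factor applied to the low-rank update B @ A

print(f"trainable LoRA weights: {lora_params:,}")
print(f"frozen weights in adapted layers: {frozen_params:,}")
print(f"ratio: {lora_params / frozen_params:.2%}, scaling: {scaling}")
```

Under these assumptions the adapters add roughly 4% of the weights they modify, which is what keeps both local training and the exchanged updates small.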

In [None]:
from fate_llm.dataset.image_tokenizer import TokenizerImageDataset

In [None]:
test = TokenizerImageDataset()

fate_path = '/data/projects/fate'
path=fate_path + '/examples/data/cifar10/test'

test.load(path)

In [None]:
len(test)
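
Before loading, it can help to check that the data directory looks as expected. The sketch below counts image files per class folder, assuming a one-sub-folder-per-class layout; this tutorial does not confirm that layout for TokenizerImageDataset, so treat both the layout and the helper `count_images_per_class` as assumptions.

```python
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def count_images_per_class(root):
    """Return {class_folder_name: image_file_count} for an assumed class-per-folder layout."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts


# Demo on a throwaway directory so the sketch runs without the real data.
with tempfile.TemporaryDirectory() as tmp:
    for label, n in (("airplane", 2), ("cat", 3)):
        (Path(tmp) / label).mkdir()
        for i in range(n):
            (Path(tmp) / label / f"{i}.png").touch()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)
```

To use it on the real data, pass the same path that is given to test.load above.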

In [None]:
import torch as t
import os
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.component.homo_nn import DatasetParam, TrainerParam
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader
from pipeline.interface import Data
from transformers import ViTConfig


fate_torch_hook(t)


fate_project_path = '/data/projects/fate'
guest = 9999
host = 9999

pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,
                                                                            arbiter=host)
data_0 = {"name": "cifar10", "namespace": "experiment"}
data_path = fate_project_path + '/examples/data/cifar10/train'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)

reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_0)

#reader_1 = Reader(name="reader_1")
#reader_1.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
#reader_1.get_party_instance(role='host', party_id=host).component_param(table=data_0)
## Add your pretrained model path below; the model will be loaded from this path


## LoraConfig
from peft import LoraConfig, TaskType
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],
)
"""
    LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,
    #target_modules=['c_attn']
)
"""


model_path = 'google/vit-base-patch16-224'
model = t.nn.Sequential(
    t.nn.CustModel(module_name='pellm.vit', class_name='vit', pretrained_path=model_path,
                   peft_config=lora_config.to_dict(), peft_type="LoraConfig", num_labels=1000,  pad_token_id=50256),
    t.nn.CustModel(module_name='sigmoid', class_name='Sigmoid')
)

# DatasetParam
dataset_param = DatasetParam(dataset_name='image_tokenizer')
# TrainerParam
trainer_param = TrainerParam(trainer_name='fedavg_vit_trainer', epochs=1, batch_size=8,
                             data_loader_worker=1)

nn_component = HomoNN(name='nn_0', model=model)

# set parameter for client 1
nn_component.get_party_instance(role='guest', party_id=guest).component_param(
    loss=t.nn.CrossEntropyLoss(),
    optimizer = t.optim.Adam(lr=0.0001, eps=1e-8),
    dataset=dataset_param,       
    trainer=trainer_param,
    torch_seed=100 
)

# set parameter for client 2
nn_component.get_party_instance(role='host', party_id=host).component_param(
    loss=t.nn.CrossEntropyLoss(),
    optimizer = t.optim.Adam(lr=0.0001, eps=1e-8),
    dataset=dataset_param,       
    trainer=trainer_param,
    torch_seed=100 
)

# set parameter for server (the arbiter role is assigned to the host party ID in set_roles)
nn_component.get_party_instance(role='arbiter', party_id=host).component_param(
    trainer=trainer_param
)

pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.compile()

pipeline.fit()

You can use this script to submit the task. Training takes a long time and generates a long log, so we do not run it here.
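
The trainer name fedavg_vit_trainer suggests federated averaging on the server side. As a conceptual sketch only, and not FATE's implementation, the hypothetical function `fed_avg` below averages per-client parameter vectors weighted by each client's sample count.

```python
from typing import List, Sequence


def fed_avg(client_params: Sequence[Sequence[float]],
            sample_counts: Sequence[int]) -> List[float]:
    """Weighted average of flattened parameter vectors, one vector per client."""
    if not client_params or len(client_params) != len(sample_counts):
        raise ValueError("need one sample count per client")
    total = sum(sample_counts)
    if total <= 0:
        raise ValueError("total sample count must be positive")
    dim = len(client_params[0])
    averaged = [0.0] * dim
    for params, n in zip(client_params, sample_counts):
        if len(params) != dim:
            raise ValueError("all clients must send vectors of the same length")
        for i, value in enumerate(params):
            averaged[i] += value * n / total
    return averaged


# Toy example: the host holds three times as many samples as the guest.
guest_params = [1.0, 2.0]
host_params = [3.0, 4.0]
global_params = fed_avg([guest_params, host_params], sample_counts=[1, 3])
print(global_params)
```

Weighting by sample count means a party with more data pulls the aggregated parameters closer to its own update.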