# Customize Client Training Scripts Based on Tasks


In the previous section, we converted the client training code to federated learning (FL) training. We simplified the process a bit there: we assumed the client always trains, regardless of which task the server issues.

What if there are multiple tasks? The client should then take a different action for each task.

By default, NVFlare's Client API issues three different tasks: "train", "evaluate", and "submit_model".

You can check which of these tasks the client received with:

```python
flare.is_train()
flare.is_evaluate()
flare.is_submit_model()
```
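The `flare.is_*()` checks amount to dispatching on the current task. The sketch below illustrates that idea without NVFlare: it assumes the task arrives as a plain string (`"train"`, `"evaluate"`, `"submit_model"`), and the handler functions are hypothetical placeholders, not part of the Client API.

```python
from typing import Callable, Dict


# Hypothetical handlers standing in for the real training, evaluation,
# and model-submission logic of a client script.
def handle_train() -> str:
    return "trained and sent model + metrics"


def handle_evaluate() -> str:
    return "evaluated and sent metrics"


def handle_submit_model() -> str:
    return "sent best local model"


# Map each task name to the action the client should take.
HANDLERS: Dict[str, Callable[[], str]] = {
    "train": handle_train,
    "evaluate": handle_evaluate,
    "submit_model": handle_submit_model,
}


def dispatch(task_name: str) -> str:
    """Run the handler registered for task_name; unknown tasks are an error."""
    try:
        handler = HANDLERS[task_name]
    except KeyError:
        raise ValueError(f"unsupported task: {task_name}") from None
    return handler()


if __name__ == "__main__":
    for task in ["train", "evaluate", "submit_model"]:
        print(task, "->", dispatch(task))
```

Failing loudly on an unknown task name keeps a misconfigured job from silently doing nothing.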

So we need to modify our existing training code to include both training and evaluation logic.

## Training logic changes

Besides the training logic we have seen before, we also need to evaluate the model and obtain its accuracy.
Here we perform two evaluations.

Evaluate the locally trained model:

```python
# (5.2) evaluation on local trained model to save best model
local_accuracy = evaluate(net.state_dict())
```

Evaluate the received global model:

```python
# (5.3) evaluate on received model for model selection
accuracy = evaluate(input_model.params)
```

Then add the global model's accuracy to the `metrics` parameter of the `FLModel` before sending it back to the server.

```python
output_model = flare.FLModel(
    params=net.cpu().state_dict(),
    metrics={"accuracy": accuracy},
    meta={"NUM_STEPS_CURRENT_ROUND": steps},
)
```


Putting it together, the newly added training logic looks like this:

> Note: the `evaluate()` function is discussed in the next section.


```python
# (5.2) evaluation on local trained model to save best model
local_accuracy = evaluate(net.state_dict())
print(f"({client_id}) Evaluating local trained model. Accuracy on the 10000 test images: {local_accuracy}")
if local_accuracy > best_accuracy:
    best_accuracy = local_accuracy
    torch.save(net.state_dict(), model_path)

# (5.3) evaluate on received model for model selection
accuracy = evaluate(input_model.params)
print(
    f"({client_id}) Evaluating received model for model selection. Accuracy on the 10000 test images: {accuracy}"
)

# (5.4) construct trained FL model
output_model = flare.FLModel(
    params=net.cpu().state_dict(),
    metrics={"accuracy": accuracy},
    meta={"NUM_STEPS_CURRENT_ROUND": steps},
)

# (5.5) send model back to NVFlare
flare.send(output_model)
```
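Step (5.2) keeps a checkpoint of the best local model seen so far: it writes the weights only when the new local accuracy beats `best_accuracy`. Below is a stdlib-only sketch of that pattern. It assumes the weights are a plain dict and uses `pickle` in place of `torch.save`; the `BestModelTracker` class is hypothetical and not part of NVFlare or PyTorch.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


class BestModelTracker:
    """Keep the weights with the highest accuracy seen so far on disk.

    Hypothetical helper mirroring the `if local_accuracy > best_accuracy`
    check in step (5.2), with pickle standing in for torch.save.
    """

    def __init__(self, model_path: str) -> None:
        self.model_path = model_path
        self.best_accuracy = 0.0

    def update(self, weights: Dict[str, Any], accuracy: float) -> bool:
        """Save weights if accuracy improves on the best; return True if saved."""
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            with open(self.model_path, "wb") as f:
                pickle.dump(weights, f)
            return True
        return False

    def load_best(self) -> Dict[str, Any]:
        """Load the best weights, e.g. when answering a submit_model task."""
        with open(self.model_path, "rb") as f:
            return pickle.load(f)


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "best_model.pkl")
    tracker = BestModelTracker(path)
    # Only rounds 1 and 3 improve on the best accuracy so far.
    for round_num, acc in enumerate([52, 48, 61], start=1):
        saved = tracker.update({"round": round_num}, acc)
        print(f"round {round_num}: accuracy={acc} saved={saved}")
    print("best:", tracker.best_accuracy, tracker.load_best())
```

Saving only on improvement means the file on disk always holds the best model, which is exactly what a later `submit_model` task asks for.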

## The evaluate() function

The `evaluate()` function needs the test data. It is defined as a nested (inner) function, so it can use `testloader` directly from the enclosing scope.
It returns the accuracy as an integer percentage.


```python
# wraps evaluation logic into a method to re-use for
#       evaluation on both trained and received model
def evaluate(input_weights):
    net = Net()
    net.load_state_dict(input_weights)
    # (optional) use GPU to speed things up
    net.to(DEVICE)

    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            # (optional) use GPU to speed things up
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)

            # calculate outputs by running images through the network
            outputs = net(images)

            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct // total
```
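Note that `100 * correct // total` uses integer floor division, so the reported accuracy is a whole-number percentage rounded down. The stdlib-only sketch below reproduces the same counting logic on plain lists of scores (assumed stand-ins for the network outputs), taking the highest-scoring class as the prediction, as `torch.max(..., 1)` does.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the highest score, i.e. the predicted class."""
    return max(range(len(scores)), key=lambda i: scores[i])


def accuracy_percent(outputs: List[Sequence[float]], labels: List[int]) -> int:
    """Same arithmetic as evaluate(): floor of 100 * correct / total."""
    correct = 0
    total = 0
    for scores, label in zip(outputs, labels):
        predicted = argmax(scores)
        total += 1
        correct += int(predicted == label)
    return 100 * correct // total


if __name__ == "__main__":
    outputs = [[0.1, 0.9, 0.0], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]]
    labels = [1, 0, 0]  # the third prediction (class 2) is wrong
    print(accuracy_percent(outputs, labels))  # 2 of 3 correct -> 66
```

For example, 6,543 correct predictions out of 10,000 test images are reported as 65, not 65.43.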


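The overall logic becomes:

```python
# input_model, evaluate(), Net and model_path are defined earlier in the script
if flare.is_train():
    # train, then evaluate the local and the received model (steps 5.2 to 5.4)
    ...
    # send the trained model and its metrics back
    flare.send(output_model)

elif flare.is_evaluate():
    # evaluate only; this can be used for cross-site evaluation
    accuracy = evaluate(input_model.params)
    # send the metrics back to the server
    flare.send(flare.FLModel(metrics={"accuracy": accuracy}))

elif flare.is_submit_model():
    # the server expects the client to submit its best model:
    # load the best model saved during training and send it back
    net = Net()
    net.load_state_dict(torch.load(model_path))
    flare.send(flare.FLModel(params=net.cpu().state_dict()))
```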
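To see this control flow end to end without a running NVFlare system, the toy simulation below replays a fixed sequence of tasks through the same `if`/`elif` structure. `FakeTaskSource` is a hypothetical stand-in for the `flare` module that only mimics the three `is_*()` checks plus a `send()` recorder; training and evaluation are reduced to placeholder accuracy numbers.

```python
from typing import Dict, List, Optional


class FakeTaskSource:
    """Hypothetical stand-in for the flare module (not part of NVFlare).

    It replays a fixed list of task names and records what the client sends.
    """

    def __init__(self, tasks: List[str]) -> None:
        self._tasks = list(tasks)
        self.current: Optional[str] = None
        self.sent: List[Dict[str, object]] = []

    def next_task(self) -> bool:
        """Move to the next task; return False when no tasks are left."""
        if not self._tasks:
            return False
        self.current = self._tasks.pop(0)
        return True

    def is_train(self) -> bool:
        return self.current == "train"

    def is_evaluate(self) -> bool:
        return self.current == "evaluate"

    def is_submit_model(self) -> bool:
        return self.current == "submit_model"

    def send(self, message: Dict[str, object]) -> None:
        self.sent.append(message)


def run_client(source: FakeTaskSource, train_accuracies: List[int], eval_accuracy: int) -> None:
    """Same if/elif dispatch as the client script; numbers are placeholders."""
    best_accuracy = -1
    best_round: Optional[int] = None
    train_round = 0
    while source.next_task():
        if source.is_train():
            # placeholder: pretend training produced this local accuracy
            local_accuracy = train_accuracies[train_round]
            if local_accuracy > best_accuracy:
                best_accuracy, best_round = local_accuracy, train_round
            source.send({"task": "train", "round": train_round, "local_accuracy": local_accuracy})
            train_round += 1
        elif source.is_evaluate():
            # evaluate only (e.g. cross-site evaluation): send metrics back
            source.send({"task": "evaluate", "accuracy": eval_accuracy})
        elif source.is_submit_model():
            # submit the best local model seen so far
            source.send({"task": "submit_model", "best_round": best_round})


if __name__ == "__main__":
    src = FakeTaskSource(["train", "train", "evaluate", "submit_model"])
    run_client(src, train_accuracies=[55, 62], eval_accuracy=60)
    for msg in src.sent:
        print(msg)
```

Because round 1 reaches a higher local accuracy (62) than round 0 (55), the `submit_model` reply points at round 1.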

For the complete script, see [client.py](./code/src/client.py).
